diff options
Diffstat (limited to 'net/socket.c')
| -rw-r--r-- | net/socket.c | 1184 |
1 files changed, 665 insertions, 519 deletions
diff --git a/net/socket.c b/net/socket.c index 50cf75730fd7..136b98c54fb3 100644 --- a/net/socket.c +++ b/net/socket.c @@ -57,6 +57,7 @@ #include <linux/mm.h> #include <linux/socket.h> #include <linux/file.h> +#include <linux/splice.h> #include <linux/net.h> #include <linux/interrupt.h> #include <linux/thread_info.h> @@ -87,6 +88,7 @@ #include <linux/xattr.h> #include <linux/nospec.h> #include <linux/indirect_call_wrapper.h> +#include <linux/io_uring/net.h> #include <linux/uaccess.h> #include <asm/unistd.h> @@ -106,6 +108,9 @@ #include <net/busy_poll.h> #include <linux/errqueue.h> #include <linux/ptp_clock_kernel.h> +#include <trace/events/sock.h> + +#include "core/dev.h" #ifdef CONFIG_NET_RX_BUSY_POLL unsigned int sysctl_net_busy_read __read_mostly; @@ -125,19 +130,19 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd, unsigned long arg); #endif static int sock_fasync(int fd, struct file *filp, int on); -static ssize_t sock_sendpage(struct file *file, struct page *page, - int offset, size_t size, loff_t *ppos, int more); static ssize_t sock_splice_read(struct file *file, loff_t *ppos, struct pipe_inode_info *pipe, size_t len, unsigned int flags); +static void sock_splice_eof(struct file *file); #ifdef CONFIG_PROC_FS static void sock_show_fdinfo(struct seq_file *m, struct file *f) { struct socket *sock = f->private_data; + const struct proto_ops *ops = READ_ONCE(sock->ops); - if (sock->ops->show_fdinfo) - sock->ops->show_fdinfo(m, sock); + if (ops->show_fdinfo) + ops->show_fdinfo(m, sock); } #else #define sock_show_fdinfo NULL @@ -150,7 +155,6 @@ static void sock_show_fdinfo(struct seq_file *m, struct file *f) static const struct file_operations socket_file_ops = { .owner = THIS_MODULE, - .llseek = no_llseek, .read_iter = sock_read_iter, .write_iter = sock_write_iter, .poll = sock_poll, @@ -158,12 +162,13 @@ static const struct file_operations socket_file_ops = { #ifdef CONFIG_COMPAT .compat_ioctl = compat_sock_ioctl, #endif + .uring_cmd = io_uring_cmd_sock, .mmap = sock_mmap, .release = sock_close, .fasync = sock_fasync, - .sendpage = sock_sendpage, - .splice_write = generic_splice_sendpage, + .splice_write = splice_to_socket, .splice_read = sock_splice_read, + .splice_eof = sock_splice_eof, .show_fdinfo = sock_show_fdinfo, }; @@ -271,28 +276,41 @@ int move_addr_to_kernel(void __user *uaddr, int ulen, struct sockaddr_storage *k static int move_addr_to_user(struct sockaddr_storage *kaddr, int klen, void __user *uaddr, int __user *ulen) { - int err; int len; BUG_ON(klen > sizeof(struct sockaddr_storage)); - err = get_user(len, ulen); - if (err) - return err; + + if (can_do_masked_user_access()) + ulen = masked_user_access_begin(ulen); + else if (!user_access_begin(ulen, 4)) + return -EFAULT; + + unsafe_get_user(len, ulen, efault_end); + if (len > klen) len = klen; - if (len < 0) - return -EINVAL; + /* + * "fromlen shall refer to the value before truncation.." + * 1003.1g + */ + if (len >= 0) + unsafe_put_user(klen, ulen, efault_end); + + user_access_end(); + if (len) { + if (len < 0) + return -EINVAL; if (audit_sockaddr(klen, kaddr)) return -ENOMEM; if (copy_to_user(uaddr, kaddr, len)) return -EFAULT; } - /* - * "fromlen shall refer to the value before truncation.." - * 1003.1g - */ - return __put_user(klen, ulen); + return 0; + +efault_end: + user_access_end(); + return -EFAULT; } static struct kmem_cache *sock_inode_cachep __ro_after_init; @@ -301,7 +319,7 @@ static struct inode *sock_alloc_inode(struct super_block *sb) { struct socket_alloc *ei; - ei = kmem_cache_alloc(sock_inode_cachep, GFP_KERNEL); + ei = alloc_inode_sb(sb, sock_inode_cachep, GFP_KERNEL); if (!ei) return NULL; init_waitqueue_head(&ei->socket.wq.wait); @@ -339,7 +357,7 @@ static void init_inodecache(void) 0, (SLAB_HWCACHE_ALIGN | SLAB_RECLAIM_ACCOUNT | - SLAB_MEM_SPREAD | SLAB_ACCOUNT), + SLAB_ACCOUNT), init_once); BUG_ON(sock_inode_cachep == NULL); } @@ -355,7 +373,7 @@ static const struct super_operations sockfs_ops = { */ static char *sockfs_dname(struct dentry *dentry, char *buffer, int buflen) { - return dynamic_dname(dentry, buffer, buflen, "socket:[%lu]", + return dynamic_dname(buffer, buflen, "socket:[%lu]", d_inode(dentry)->i_ino); } @@ -385,7 +403,7 @@ static const struct xattr_handler sockfs_xattr_handler = { }; static int sockfs_security_xattr_set(const struct xattr_handler *handler, - struct user_namespace *mnt_userns, + struct mnt_idmap *idmap, struct dentry *dentry, struct inode *inode, const char *suffix, const void *value, size_t size, int flags) @@ -399,7 +417,7 @@ static const struct xattr_handler sockfs_security_xattr_handler = { .set = sockfs_security_xattr_set, }; -static const struct xattr_handler *sockfs_xattr_handlers[] = { +static const struct xattr_handler * const sockfs_xattr_handlers[] = { &sockfs_xattr_handler, &sockfs_security_xattr_handler, NULL @@ -449,7 +467,9 @@ static struct file_system_type sock_fs_type = { * * Returns the &file bound with @sock, implicitly storing it * in sock->file. If dname is %NULL, sets to "". - * On failure the return is a ERR pointer (see linux/err.h). + * + * On failure @sock is released, and an ERR pointer is returned. + * * This function uses GFP_KERNEL internally. */ @@ -468,9 +488,15 @@ struct file *sock_alloc_file(struct socket *sock, int flags, const char *dname) return file; } + file->f_mode |= FMODE_NOWAIT; sock->file = file; file->private_data = sock; stream_open(SOCK_INODE(sock), file); + /* + * Disable permission and pre-content events, but enable legacy + * inotify events for legacy users. + */ + file_set_fsnotify_mode(file, FMODE_NONOTIFY_PERM); return file; } EXPORT_SYMBOL(sock_alloc_file); @@ -503,8 +529,8 @@ static int sock_map_fd(struct socket *sock, int flags) struct socket *sock_from_file(struct file *file) { - if (file->f_op == &socket_file_ops) - return file->private_data; /* set in sock_map_fd */ + if (likely(file->f_op == &socket_file_ops)) + return file->private_data; /* set in sock_alloc_file */ return NULL; } @@ -543,24 +569,6 @@ struct socket *sockfd_lookup(int fd, int *err) } EXPORT_SYMBOL(sockfd_lookup); -static struct socket *sockfd_lookup_light(int fd, int *err, int *fput_needed) -{ - struct fd f = fdget(fd); - struct socket *sock; - - *err = -EBADF; - if (f.file) { - sock = sock_from_file(f.file); - if (likely(sock)) { - *fput_needed = f.flags & FDPUT_FPUT; - return sock; - } - *err = -ENOTSOCK; - fdput(f); - } - return NULL; -} - static ssize_t sockfs_listxattr(struct dentry *dentry, char *buffer, size_t size) { @@ -589,18 +597,20 @@ static ssize_t sockfs_listxattr(struct dentry *dentry, char *buffer, return used; } -static int sockfs_setattr(struct user_namespace *mnt_userns, +static int sockfs_setattr(struct mnt_idmap *idmap, struct dentry *dentry, struct iattr *iattr) { - int err = simple_setattr(&init_user_ns, dentry, iattr); + int err = simple_setattr(&nop_mnt_idmap, dentry, iattr); if (!err && (iattr->ia_valid & ATTR_UID)) { struct socket *sock = SOCKET_I(d_inode(dentry)); - if (sock->sk) - sock->sk->sk_uid = iattr->ia_uid; - else + if (sock->sk) { + /* Paired with READ_ONCE() in sk_uid() */ + WRITE_ONCE(sock->sk->sk_uid, iattr->ia_uid); + } else { err = -ENOENT; + } } return err; @@ -642,12 +652,14 @@ EXPORT_SYMBOL(sock_alloc); static void __sock_release(struct socket *sock, struct inode *inode) { - if (sock->ops) { - struct module *owner = sock->ops->owner; + const struct proto_ops *ops = READ_ONCE(sock->ops); + + if (ops) { + struct module *owner = ops->owner; if (inode) inode_lock(inode); - sock->ops->release(sock); + ops->release(sock); sock->sk = NULL; if (inode) inode_unlock(inode); @@ -679,12 +691,12 @@ void sock_release(struct socket *sock) } EXPORT_SYMBOL(sock_release); -void __sock_tx_timestamp(__u16 tsflags, __u8 *tx_flags) +void __sock_tx_timestamp(__u32 tsflags, __u8 *tx_flags) { u8 flags = *tx_flags; if (tsflags & SOF_TIMESTAMPING_TX_HARDWARE) - flags |= SKBTX_HW_TSTAMP; + flags |= SKBTX_HW_TSTAMP_NOBPF; if (tsflags & SOF_TIMESTAMPING_TX_SOFTWARE) flags |= SKBTX_SW_TSTAMP; @@ -692,6 +704,9 @@ void __sock_tx_timestamp(__u16 tsflags, __u8 *tx_flags) if (tsflags & SOF_TIMESTAMPING_TX_SCHED) flags |= SKBTX_SCHED_TSTAMP; + if (tsflags & SOF_TIMESTAMPING_TX_COMPLETION) + flags |= SKBTX_COMPLETION_TSTAMP; + *tx_flags = flags; } EXPORT_SYMBOL(__sock_tx_timestamp); @@ -700,15 +715,33 @@ INDIRECT_CALLABLE_DECLARE(int inet_sendmsg(struct socket *, struct msghdr *, size_t)); INDIRECT_CALLABLE_DECLARE(int inet6_sendmsg(struct socket *, struct msghdr *, size_t)); + +static noinline void call_trace_sock_send_length(struct sock *sk, int ret, + int flags) +{ + trace_sock_send_length(sk, ret, 0); +} + static inline int sock_sendmsg_nosec(struct socket *sock, struct msghdr *msg) { - int ret = INDIRECT_CALL_INET(sock->ops->sendmsg, inet6_sendmsg, + int ret = INDIRECT_CALL_INET(READ_ONCE(sock->ops)->sendmsg, inet6_sendmsg, inet_sendmsg, sock, msg, msg_data_left(msg)); BUG_ON(ret == -EIOCBQUEUED); + + if (trace_sock_send_length_enabled()) + call_trace_sock_send_length(sock->sk, ret, 0); return ret; } +static int __sock_sendmsg(struct socket *sock, struct msghdr *msg) +{ + int err = security_socket_sendmsg(sock, msg, + msg_data_left(msg)); + + return err ?: sock_sendmsg_nosec(sock, msg); +} + /** * sock_sendmsg - send a message through @sock * @sock: socket @@ -719,10 +752,21 @@ static inline int sock_sendmsg_nosec(struct socket *sock, struct msghdr *msg) */ int sock_sendmsg(struct socket *sock, struct msghdr *msg) { - int err = security_socket_sendmsg(sock, msg, - msg_data_left(msg)); + struct sockaddr_storage *save_addr = (struct sockaddr_storage *)msg->msg_name; + struct sockaddr_storage address; + int save_len = msg->msg_namelen; + int ret; - return err ?: sock_sendmsg_nosec(sock, msg); + if (msg->msg_name) { + memcpy(&address, msg->msg_name, msg->msg_namelen); + msg->msg_name = &address; + } + + ret = __sock_sendmsg(sock, msg); + msg->msg_name = save_addr; + msg->msg_namelen = save_len; + + return ret; } EXPORT_SYMBOL(sock_sendmsg); @@ -741,38 +785,11 @@ EXPORT_SYMBOL(sock_sendmsg); int kernel_sendmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec, size_t num, size_t size) { - iov_iter_kvec(&msg->msg_iter, WRITE, vec, num, size); + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, num, size); return sock_sendmsg(sock, msg); } EXPORT_SYMBOL(kernel_sendmsg); -/** - * kernel_sendmsg_locked - send a message through @sock (kernel-space) - * @sk: sock - * @msg: message header - * @vec: output s/g array - * @num: output s/g array length - * @size: total message data size - * - * Builds the message data with @vec and sends it through @sock. - * Returns the number of bytes sent, or an error code. - * Caller must hold @sk. - */ - -int kernel_sendmsg_locked(struct sock *sk, struct msghdr *msg, - struct kvec *vec, size_t num, size_t size) -{ - struct socket *sock = sk->sk_socket; - - if (!sock->ops->sendmsg_locked) - return sock_no_sendmsg_locked(sk, msg, size); - - iov_iter_kvec(&msg->msg_iter, WRITE, vec, num, size); - - return sock->ops->sendmsg_locked(sk, msg, msg_data_left(msg)); -} -EXPORT_SYMBOL(kernel_sendmsg_locked); - static bool skb_is_err_queue(const struct sk_buff *skb) { /* pkt_type of skbs enqueued on the error queue are set to @@ -796,7 +813,28 @@ static bool skb_is_swtx_tstamp(const struct sk_buff *skb, int false_tstamp) return skb->tstamp && !false_tstamp && skb_is_err_queue(skb); } -static void put_ts_pktinfo(struct msghdr *msg, struct sk_buff *skb) +static ktime_t get_timestamp(struct sock *sk, struct sk_buff *skb, int *if_index) +{ + bool cycles = READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_BIND_PHC; + struct skb_shared_hwtstamps *shhwtstamps = skb_hwtstamps(skb); + struct net_device *orig_dev; + ktime_t hwtstamp; + + rcu_read_lock(); + orig_dev = dev_get_by_napi_id(skb_napi_id(skb)); + if (orig_dev) { + *if_index = orig_dev->ifindex; + hwtstamp = netdev_get_tstamp(orig_dev, shhwtstamps, cycles); + } else { + hwtstamp = shhwtstamps->hwtstamp; + } + rcu_read_unlock(); + + return hwtstamp; +} + +static void put_ts_pktinfo(struct msghdr *msg, struct sk_buff *skb, + int if_index) { struct scm_ts_pktinfo ts_pktinfo; struct net_device *orig_dev; @@ -806,17 +844,66 @@ static void put_ts_pktinfo(struct msghdr *msg, struct sk_buff *skb) memset(&ts_pktinfo, 0, sizeof(ts_pktinfo)); - rcu_read_lock(); - orig_dev = dev_get_by_napi_id(skb_napi_id(skb)); - if (orig_dev) - ts_pktinfo.if_index = orig_dev->ifindex; - rcu_read_unlock(); + if (!if_index) { + rcu_read_lock(); + orig_dev = dev_get_by_napi_id(skb_napi_id(skb)); + if (orig_dev) + if_index = orig_dev->ifindex; + rcu_read_unlock(); + } + ts_pktinfo.if_index = if_index; ts_pktinfo.pkt_length = skb->len - skb_mac_offset(skb); put_cmsg(msg, SOL_SOCKET, SCM_TIMESTAMPING_PKTINFO, sizeof(ts_pktinfo), &ts_pktinfo); } +bool skb_has_tx_timestamp(struct sk_buff *skb, const struct sock *sk) +{ + const struct sock_exterr_skb *serr = SKB_EXT_ERR(skb); + u32 tsflags = READ_ONCE(sk->sk_tsflags); + + if (serr->ee.ee_errno != ENOMSG || + serr->ee.ee_origin != SO_EE_ORIGIN_TIMESTAMPING) + return false; + + /* software time stamp available and wanted */ + if ((tsflags & SOF_TIMESTAMPING_SOFTWARE) && skb->tstamp) + return true; + /* hardware time stamps available and wanted */ + return (tsflags & SOF_TIMESTAMPING_RAW_HARDWARE) && + skb_hwtstamps(skb)->hwtstamp; +} + +int skb_get_tx_timestamp(struct sk_buff *skb, struct sock *sk, + struct timespec64 *ts) +{ + u32 tsflags = READ_ONCE(sk->sk_tsflags); + ktime_t hwtstamp; + int if_index = 0; + + if ((tsflags & SOF_TIMESTAMPING_SOFTWARE) && + ktime_to_timespec64_cond(skb->tstamp, ts)) + return SOF_TIMESTAMPING_TX_SOFTWARE; + + if (!(tsflags & SOF_TIMESTAMPING_RAW_HARDWARE) || + skb_is_swtx_tstamp(skb, false)) + return -ENOENT; + + if (skb_shinfo(skb)->tx_flags & SKBTX_HW_TSTAMP_NETDEV) + hwtstamp = get_timestamp(sk, skb, &if_index); + else + hwtstamp = skb_hwtstamps(skb)->hwtstamp; + + if (tsflags & SOF_TIMESTAMPING_BIND_PHC) + hwtstamp = ptp_convert_timestamp(&hwtstamp, + READ_ONCE(sk->sk_bind_phc)); + if (!ktime_to_timespec64_cond(hwtstamp, ts)) + return -ENOENT; + + return SOF_TIMESTAMPING_TX_HARDWARE; +} + /* * called from sock_recv_timestamp() if sock_flag(sk, SOCK_RCVTSTAMP) */ @@ -826,11 +913,12 @@ void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk, int need_software_tstamp = sock_flag(sk, SOCK_RCVTSTAMP); int new_tstamp = sock_flag(sk, SOCK_TSTAMP_NEW); struct scm_timestamping_internal tss; - int empty = 1, false_tstamp = 0; struct skb_shared_hwtstamps *shhwtstamps = skb_hwtstamps(skb); + int if_index; ktime_t hwtstamp; + u32 tsflags; /* Race occurred between timestamp enabling and packet receiving. Fill in the current time for now. */ @@ -872,24 +960,35 @@ void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk, } memset(&tss, 0, sizeof(tss)); - if ((sk->sk_tsflags & SOF_TIMESTAMPING_SOFTWARE) && + tsflags = READ_ONCE(sk->sk_tsflags); + if ((tsflags & SOF_TIMESTAMPING_SOFTWARE && + (tsflags & SOF_TIMESTAMPING_RX_SOFTWARE || + skb_is_err_queue(skb) || + !(tsflags & SOF_TIMESTAMPING_OPT_RX_FILTER))) && ktime_to_timespec64_cond(skb->tstamp, tss.ts + 0)) empty = 0; if (shhwtstamps && - (sk->sk_tsflags & SOF_TIMESTAMPING_RAW_HARDWARE) && + (tsflags & SOF_TIMESTAMPING_RAW_HARDWARE && + (tsflags & SOF_TIMESTAMPING_RX_HARDWARE || + skb_is_err_queue(skb) || + !(tsflags & SOF_TIMESTAMPING_OPT_RX_FILTER))) && !skb_is_swtx_tstamp(skb, false_tstamp)) { - if (sk->sk_tsflags & SOF_TIMESTAMPING_BIND_PHC) - hwtstamp = ptp_convert_timestamp(shhwtstamps, - sk->sk_bind_phc); + if_index = 0; + if (skb_shinfo(skb)->tx_flags & SKBTX_HW_TSTAMP_NETDEV) + hwtstamp = get_timestamp(sk, skb, &if_index); else hwtstamp = shhwtstamps->hwtstamp; + if (tsflags & SOF_TIMESTAMPING_BIND_PHC) + hwtstamp = ptp_convert_timestamp(&hwtstamp, + READ_ONCE(sk->sk_bind_phc)); + if (ktime_to_timespec64_cond(hwtstamp, tss.ts + 2)) { empty = 0; - if ((sk->sk_tsflags & SOF_TIMESTAMPING_OPT_PKTINFO) && + if ((tsflags & SOF_TIMESTAMPING_OPT_PKTINFO) && !skb_is_err_queue(skb)) - put_ts_pktinfo(msg, skb); + put_ts_pktinfo(msg, skb, if_index); } } if (!empty) { @@ -906,6 +1005,7 @@ void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk, } EXPORT_SYMBOL_GPL(__sock_recv_timestamp); +#ifdef CONFIG_WIRELESS void __sock_recv_wifi_status(struct msghdr *msg, struct sock *sk, struct sk_buff *skb) { @@ -921,6 +1021,7 @@ void __sock_recv_wifi_status(struct msghdr *msg, struct sock *sk, put_cmsg(msg, SOL_SOCKET, SCM_WIFI_STATUS, sizeof(ack), &ack); } EXPORT_SYMBOL_GPL(__sock_recv_wifi_status); +#endif static inline void sock_recv_drops(struct msghdr *msg, struct sock *sk, struct sk_buff *skb) @@ -930,24 +1031,57 @@ static inline void sock_recv_drops(struct msghdr *msg, struct sock *sk, sizeof(__u32), &SOCK_SKB_CB(skb)->dropcount); } -void __sock_recv_ts_and_drops(struct msghdr *msg, struct sock *sk, - struct sk_buff *skb) +static void sock_recv_mark(struct msghdr *msg, struct sock *sk, + struct sk_buff *skb) +{ + if (sock_flag(sk, SOCK_RCVMARK) && skb) { + /* We must use a bounce buffer for CONFIG_HARDENED_USERCOPY=y */ + __u32 mark = skb->mark; + + put_cmsg(msg, SOL_SOCKET, SO_MARK, sizeof(__u32), &mark); + } +} + +static void sock_recv_priority(struct msghdr *msg, struct sock *sk, + struct sk_buff *skb) +{ + if (sock_flag(sk, SOCK_RCVPRIORITY) && skb) { + __u32 priority = skb->priority; + + put_cmsg(msg, SOL_SOCKET, SO_PRIORITY, sizeof(__u32), &priority); + } +} + +void __sock_recv_cmsgs(struct msghdr *msg, struct sock *sk, + struct sk_buff *skb) { sock_recv_timestamp(msg, sk, skb); sock_recv_drops(msg, sk, skb); + sock_recv_mark(msg, sk, skb); + sock_recv_priority(msg, sk, skb); } -EXPORT_SYMBOL_GPL(__sock_recv_ts_and_drops); +EXPORT_SYMBOL_GPL(__sock_recv_cmsgs); INDIRECT_CALLABLE_DECLARE(int inet_recvmsg(struct socket *, struct msghdr *, size_t, int)); INDIRECT_CALLABLE_DECLARE(int inet6_recvmsg(struct socket *, struct msghdr *, size_t, int)); + +static noinline void call_trace_sock_recv_length(struct sock *sk, int ret, int flags) +{ + trace_sock_recv_length(sk, ret, flags); +} + static inline int sock_recvmsg_nosec(struct socket *sock, struct msghdr *msg, int flags) { - return INDIRECT_CALL_INET(sock->ops->recvmsg, inet6_recvmsg, - inet_recvmsg, sock, msg, msg_data_left(msg), - flags); + int ret = INDIRECT_CALL_INET(READ_ONCE(sock->ops)->recvmsg, + inet6_recvmsg, + inet_recvmsg, sock, msg, + msg_data_left(msg), flags); + if (trace_sock_recv_length_enabled()) + call_trace_sock_recv_length(sock->sk, ret, flags); + return ret; } /** @@ -987,36 +1121,33 @@ int kernel_recvmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec, size_t num, size_t size, int flags) { msg->msg_control_is_user = false; - iov_iter_kvec(&msg->msg_iter, READ, vec, num, size); + iov_iter_kvec(&msg->msg_iter, ITER_DEST, vec, num, size); return sock_recvmsg(sock, msg, flags); } EXPORT_SYMBOL(kernel_recvmsg); -static ssize_t sock_sendpage(struct file *file, struct page *page, - int offset, size_t size, loff_t *ppos, int more) +static ssize_t sock_splice_read(struct file *file, loff_t *ppos, + struct pipe_inode_info *pipe, size_t len, + unsigned int flags) { - struct socket *sock; - int flags; - - sock = file->private_data; + struct socket *sock = file->private_data; + const struct proto_ops *ops; - flags = (file->f_flags & O_NONBLOCK) ? MSG_DONTWAIT : 0; - /* more is a combination of MSG_MORE and MSG_SENDPAGE_NOTLAST */ - flags |= more; + ops = READ_ONCE(sock->ops); + if (unlikely(!ops->splice_read)) + return copy_splice_read(file, ppos, pipe, len, flags); - return kernel_sendpage(sock, page, offset, size, flags); + return ops->splice_read(sock, ppos, pipe, len, flags); } -static ssize_t sock_splice_read(struct file *file, loff_t *ppos, - struct pipe_inode_info *pipe, size_t len, - unsigned int flags) +static void sock_splice_eof(struct file *file) { struct socket *sock = file->private_data; + const struct proto_ops *ops; - if (unlikely(!sock->ops->splice_read)) - return generic_file_splice_read(file, ppos, pipe, len, flags); - - return sock->ops->splice_read(sock, ppos, pipe, len, flags); + ops = READ_ONCE(sock->ops); + if (ops->splice_eof) + ops->splice_eof(sock); } static ssize_t sock_read_iter(struct kiocb *iocb, struct iov_iter *to) @@ -1058,7 +1189,10 @@ static ssize_t sock_write_iter(struct kiocb *iocb, struct iov_iter *from) if (sock->type == SOCK_SEQPACKET) msg.msg_flags |= MSG_EOR; - res = sock_sendmsg(sock, &msg); + if (iocb->ki_flags & IOCB_NOSIGNAL) + msg.msg_flags |= MSG_NOSIGNAL; + + res = __sock_sendmsg(sock, &msg); *from = msg.msg_iter; return res; } @@ -1069,12 +1203,10 @@ static ssize_t sock_write_iter(struct kiocb *iocb, struct iov_iter *from) */ static DEFINE_MUTEX(br_ioctl_mutex); -static int (*br_ioctl_hook)(struct net *net, struct net_bridge *br, - unsigned int cmd, struct ifreq *ifr, +static int (*br_ioctl_hook)(struct net *net, unsigned int cmd, void __user *uarg); -void brioctl_set(int (*hook)(struct net *net, struct net_bridge *br, - unsigned int cmd, struct ifreq *ifr, +void brioctl_set(int (*hook)(struct net *net, unsigned int cmd, void __user *uarg)) { mutex_lock(&br_ioctl_mutex); @@ -1083,8 +1215,7 @@ void brioctl_set(int (*hook)(struct net *net, struct net_bridge *br, } EXPORT_SYMBOL(brioctl_set); -int br_ioctl_call(struct net *net, struct net_bridge *br, unsigned int cmd, - struct ifreq *ifr, void __user *uarg) +int br_ioctl_call(struct net *net, unsigned int cmd, void __user *uarg) { int err = -ENOPKG; @@ -1093,7 +1224,7 @@ int br_ioctl_call(struct net *net, struct net_bridge *br, unsigned int cmd, mutex_lock(&br_ioctl_mutex); if (br_ioctl_hook) - err = br_ioctl_hook(net, br, cmd, ifr, uarg); + err = br_ioctl_hook(net, cmd, uarg); mutex_unlock(&br_ioctl_mutex); return err; @@ -1113,13 +1244,14 @@ EXPORT_SYMBOL(vlan_ioctl_set); static long sock_do_ioctl(struct net *net, struct socket *sock, unsigned int cmd, unsigned long arg) { + const struct proto_ops *ops = READ_ONCE(sock->ops); struct ifreq ifr; bool need_copyout; int err; void __user *argp = (void __user *)arg; void __user *data; - err = sock->ops->ioctl(sock, cmd, arg); + err = ops->ioctl(sock, cmd, arg); /* * If this ioctl is unknown try to hand it down @@ -1148,6 +1280,7 @@ static long sock_do_ioctl(struct net *net, struct socket *sock, static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg) { + const struct proto_ops *ops; struct socket *sock; struct sock *sk; void __user *argp = (void __user *)arg; @@ -1155,6 +1288,7 @@ static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg) struct net *net; sock = file->private_data; + ops = READ_ONCE(sock->ops); sk = sock->sk; net = sock_net(sk); if (unlikely(cmd >= SIOCDEVPRIVATE && cmd <= (SIOCDEVPRIVATE + 15))) { @@ -1190,7 +1324,9 @@ static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg) case SIOCSIFBR: case SIOCBRADDBR: case SIOCBRDELBR: - err = br_ioctl_call(net, NULL, cmd, NULL, argp); + case SIOCBRADDIF: + case SIOCBRDELIF: + err = br_ioctl_call(net, cmd, argp); break; case SIOCGIFVLAN: case SIOCSIFVLAN: @@ -1212,23 +1348,23 @@ static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg) break; case SIOCGSTAMP_OLD: case SIOCGSTAMPNS_OLD: - if (!sock->ops->gettstamp) { + if (!ops->gettstamp) { err = -ENOIOCTLCMD; break; } - err = sock->ops->gettstamp(sock, argp, - cmd == SIOCGSTAMP_OLD, - !IS_ENABLED(CONFIG_64BIT)); + err = ops->gettstamp(sock, argp, + cmd == SIOCGSTAMP_OLD, + !IS_ENABLED(CONFIG_64BIT)); break; case SIOCGSTAMP_NEW: case SIOCGSTAMPNS_NEW: - if (!sock->ops->gettstamp) { + if (!ops->gettstamp) { err = -ENOIOCTLCMD; break; } - err = sock->ops->gettstamp(sock, argp, - cmd == SIOCGSTAMP_NEW, - false); + err = ops->gettstamp(sock, argp, + cmd == SIOCGSTAMP_NEW, + false); break; case SIOCGIFCONF: @@ -1289,9 +1425,10 @@ EXPORT_SYMBOL(sock_create_lite); static __poll_t sock_poll(struct file *file, poll_table *wait) { struct socket *sock = file->private_data; + const struct proto_ops *ops = READ_ONCE(sock->ops); __poll_t events = poll_requested_events(wait), flag = 0; - if (!sock->ops->poll) + if (!ops->poll) return 0; if (sk_can_busy_loop(sock->sk)) { @@ -1303,14 +1440,14 @@ static __poll_t sock_poll(struct file *file, poll_table *wait) flag = POLL_BUSY_LOOP; } - return sock->ops->poll(file, sock, wait) | flag; + return ops->poll(file, sock, wait) | flag; } static int sock_mmap(struct file *file, struct vm_area_struct *vma) { struct socket *sock = file->private_data; - return sock->ops->mmap(file, sock, vma); + return READ_ONCE(sock->ops)->mmap(file, sock, vma); } static int sock_close(struct inode *inode, struct file *filp) @@ -1466,8 +1603,15 @@ int __sock_create(struct net *net, int family, int type, int protocol, rcu_read_unlock(); err = pf->create(net, sock, protocol, kern); - if (err < 0) + if (err < 0) { + /* ->create should release the allocated sock->sk object on error + * and make sure sock->sk is set to NULL to avoid use-after-free + */ + DEBUG_NET_WARN_ONCE(sock->sk, + "%ps must clear sock->sk on failure, family: %d, type: %d, protocol: %d\n", + pf->create, family, type, protocol); goto out_module_put; + } /* * Now to bump the refcnt of the [loadable] module that owns this @@ -1538,11 +1682,10 @@ int sock_create_kern(struct net *net, int family, int type, int protocol, struct } EXPORT_SYMBOL(sock_create_kern); -int __sys_socket(int family, int type, int protocol) +static struct socket *__sys_socket_create(int family, int type, int protocol) { - int retval; struct socket *sock; - int flags; + int retval; /* Check the SOCK_* constants for consistency. */ BUILD_BUG_ON(SOCK_CLOEXEC != O_CLOEXEC); @@ -1550,17 +1693,65 @@ int __sys_socket(int family, int type, int protocol) BUILD_BUG_ON(SOCK_CLOEXEC & SOCK_TYPE_MASK); BUILD_BUG_ON(SOCK_NONBLOCK & SOCK_TYPE_MASK); - flags = type & ~SOCK_TYPE_MASK; - if (flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK)) - return -EINVAL; + if ((type & ~SOCK_TYPE_MASK) & ~(SOCK_CLOEXEC | SOCK_NONBLOCK)) + return ERR_PTR(-EINVAL); type &= SOCK_TYPE_MASK; + retval = sock_create(family, type, protocol, &sock); + if (retval < 0) + return ERR_PTR(retval); + + return sock; +} + +struct file *__sys_socket_file(int family, int type, int protocol) +{ + struct socket *sock; + int flags; + + sock = __sys_socket_create(family, type, protocol); + if (IS_ERR(sock)) + return ERR_CAST(sock); + + flags = type & ~SOCK_TYPE_MASK; if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK)) flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK; - retval = sock_create(family, type, protocol, &sock); - if (retval < 0) - return retval; + return sock_alloc_file(sock, flags, NULL); +} + +/* A hook for bpf progs to attach to and update socket protocol. + * + * A static noinline declaration here could cause the compiler to + * optimize away the function. A global noinline declaration will + * keep the definition, but may optimize away the callsite. + * Therefore, __weak is needed to ensure that the call is still + * emitted, by telling the compiler that we don't know what the + * function might eventually be. + */ + +__bpf_hook_start(); + +__weak noinline int update_socket_protocol(int family, int type, int protocol) +{ + return protocol; +} + +__bpf_hook_end(); + +int __sys_socket(int family, int type, int protocol) +{ + struct socket *sock; + int flags; + + sock = __sys_socket_create(family, type, + update_socket_protocol(family, type, protocol)); + if (IS_ERR(sock)) + return PTR_ERR(sock); + + flags = type & ~SOCK_TYPE_MASK; + if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK)) + flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK; return sock_map_fd(sock, flags & (O_CLOEXEC | O_NONBLOCK)); } @@ -1633,7 +1824,7 @@ int __sys_socketpair(int family, int type, int protocol, int __user *usockvec) goto out; } - err = sock1->ops->socketpair(sock1, sock2); + err = READ_ONCE(sock1->ops)->socketpair(sock1, sock2); if (unlikely(err < 0)) { sock_release(sock2); sock_release(sock1); @@ -1672,6 +1863,20 @@ SYSCALL_DEFINE4(socketpair, int, family, int, type, int, protocol, return __sys_socketpair(family, type, protocol, usockvec); } +int __sys_bind_socket(struct socket *sock, struct sockaddr_storage *address, + int addrlen) +{ + int err; + + err = security_socket_bind(sock, (struct sockaddr *)address, + addrlen); + if (!err) + err = READ_ONCE(sock->ops)->bind(sock, + (struct sockaddr_unsized *)address, + addrlen); + return err; +} + /* * Bind a name to a socket. Nothing much to do here since it's * the protocol's responsibility to handle the local address. @@ -1684,23 +1889,20 @@ int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen) { struct socket *sock; struct sockaddr_storage address; - int err, fput_needed; - - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (sock) { - err = move_addr_to_kernel(umyaddr, addrlen, &address); - if (!err) { - err = security_socket_bind(sock, - (struct sockaddr *)&address, - addrlen); - if (!err) - err = sock->ops->bind(sock, - (struct sockaddr *) - &address, addrlen); - } - fput_light(sock->file, fput_needed); - } - return err; + CLASS(fd, f)(fd); + int err; + + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; + + err = move_addr_to_kernel(umyaddr, addrlen, &address); + if (unlikely(err)) + return err; + + return __sys_bind_socket(sock, &address, addrlen); } SYSCALL_DEFINE3(bind, int, fd, struct sockaddr __user *, umyaddr, int, addrlen) @@ -1713,26 +1915,32 @@ SYSCALL_DEFINE3(bind, int, fd, struct sockaddr __user *, umyaddr, int, addrlen) * necessary for a listen, and if that works, we mark the socket as * ready for listening. */ +int __sys_listen_socket(struct socket *sock, int backlog) +{ + int somaxconn, err; + + somaxconn = READ_ONCE(sock_net(sock->sk)->core.sysctl_somaxconn); + if ((unsigned int)backlog > somaxconn) + backlog = somaxconn; + + err = security_socket_listen(sock, backlog); + if (!err) + err = READ_ONCE(sock->ops)->listen(sock, backlog); + return err; +} int __sys_listen(int fd, int backlog) { + CLASS(fd, f)(fd); struct socket *sock; - int err, fput_needed; - int somaxconn; - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (sock) { - somaxconn = sock_net(sock->sk)->core.sysctl_somaxconn; - if ((unsigned int)backlog > somaxconn) - backlog = somaxconn; + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; - err = security_socket_listen(sock, backlog); - if (!err) - err = sock->ops->listen(sock, backlog); - - fput_light(sock->file, fput_needed); - } - return err; + return __sys_listen_socket(sock, backlog); } SYSCALL_DEFINE2(listen, int, fd, int, backlog) @@ -1740,7 +1948,7 @@ SYSCALL_DEFINE2(listen, int, fd, int, backlog) return __sys_listen(fd, backlog); } -struct file *do_accept(struct file *file, unsigned file_flags, +struct file *do_accept(struct file *file, struct proto_accept_arg *arg, struct sockaddr __user *upeer_sockaddr, int __user *upeer_addrlen, int flags) { @@ -1748,6 +1956,7 @@ struct file *do_accept(struct file *file, unsigned file_flags, struct file *newfile; int err, len; struct sockaddr_storage address; + const struct proto_ops *ops; sock = sock_from_file(file); if (!sock) @@ -1756,15 +1965,16 @@ struct file *do_accept(struct file *file, unsigned file_flags, newsock = sock_alloc(); if (!newsock) return ERR_PTR(-ENFILE); + ops = READ_ONCE(sock->ops); newsock->type = sock->type; - newsock->ops = sock->ops; + newsock->ops = ops; /* * We don't need try_module_get here, as the listening socket (sock) * has the protocol module (sock->ops->owner) held. */ - __module_get(newsock->ops->owner); + __module_get(ops->owner); newfile = sock_alloc_file(newsock, flags, sock->sk->sk_prot_creator->name); if (IS_ERR(newfile)) @@ -1774,14 +1984,13 @@ struct file *do_accept(struct file *file, unsigned file_flags, if (err) goto out_fd; - err = sock->ops->accept(sock, newsock, sock->file->f_flags | file_flags, - false); + arg->flags |= sock->file->f_flags; + err = ops->accept(sock, newsock, arg); if (err < 0) goto out_fd; if (upeer_sockaddr) { - len = newsock->ops->getname(newsock, - (struct sockaddr *)&address, 2); + len = ops->getname(newsock, (struct sockaddr *)&address, 2); if (len < 0) { err = -ECONNABORTED; goto out_fd; @@ -1799,13 +2008,10 @@ out_fd: return ERR_PTR(err); } -int __sys_accept4_file(struct file *file, unsigned file_flags, - struct sockaddr __user *upeer_sockaddr, - int __user *upeer_addrlen, int flags, - unsigned long nofile) +static int __sys_accept4_file(struct file *file, struct sockaddr __user *upeer_sockaddr, + int __user *upeer_addrlen, int flags) { - struct file *newfile; - int newfd; + struct proto_accept_arg arg = { }; if (flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK)) return -EINVAL; @@ -1813,18 +2019,7 @@ int __sys_accept4_file(struct file *file, unsigned file_flags, if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK)) flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK; - newfd = __get_unused_fd_flags(flags, nofile); - if (unlikely(newfd < 0)) - return newfd; - - newfile = do_accept(file, file_flags, upeer_sockaddr, upeer_addrlen, - flags); - if (IS_ERR(newfile)) { - put_unused_fd(newfd); - return PTR_ERR(newfile); - } - fd_install(newfd, newfile); - return newfd; + return FD_ADD(flags, do_accept(file, &arg, upeer_sockaddr, upeer_addrlen, flags)); } /* @@ -1842,18 +2037,12 @@ int __sys_accept4_file(struct file *file, unsigned file_flags, int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr, int __user *upeer_addrlen, int flags) { - int ret = -EBADF; - struct fd f; - - f = fdget(fd); - if (f.file) { - ret = __sys_accept4_file(f.file, 0, upeer_sockaddr, - upeer_addrlen, flags, - rlimit(RLIMIT_NOFILE)); - fdput(f); - } + CLASS(fd, f)(fd); - return ret; + if (fd_empty(f)) + return -EBADF; + return __sys_accept4_file(fd_file(f), upeer_sockaddr, + upeer_addrlen, flags); } SYSCALL_DEFINE4(accept4, int, fd, struct sockaddr __user *, upeer_sockaddr, @@ -1897,28 +2086,26 @@ int __sys_connect_file(struct file *file, struct sockaddr_storage *address, if (err) goto out; - err = sock->ops->connect(sock, (struct sockaddr *)address, addrlen, - sock->file->f_flags | file_flags); + err = READ_ONCE(sock->ops)->connect(sock, (struct sockaddr_unsized *)address, + addrlen, sock->file->f_flags | file_flags); out: return err; } int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen) { - int ret = -EBADF; - struct fd f; + struct sockaddr_storage address; + CLASS(fd, f)(fd); + int ret; - f = fdget(fd); - if (f.file) { - struct sockaddr_storage address; + if (fd_empty(f)) + return -EBADF; - ret = move_addr_to_kernel(uservaddr, addrlen, &address); - if (!ret) - ret = __sys_connect_file(f.file, &address, addrlen, 0); - fdput(f); - } + ret = move_addr_to_kernel(uservaddr, addrlen, &address); + if (ret) + return ret; - return ret; + return __sys_connect_file(fd_file(f), &address, addrlen, 0); } SYSCALL_DEFINE3(connect, int, fd, struct sockaddr __user *, uservaddr, @@ -1927,78 +2114,53 @@ SYSCALL_DEFINE3(connect, int, fd, struct sockaddr __user *, uservaddr, return __sys_connect(fd, uservaddr, addrlen); } -/* - * Get the local address ('name') of a socket object. Move the obtained - * name to user space. - */ - -int __sys_getsockname(int fd, struct sockaddr __user *usockaddr, - int __user *usockaddr_len) +int do_getsockname(struct socket *sock, int peer, + struct sockaddr __user *usockaddr, int __user *usockaddr_len) { - struct socket *sock; struct sockaddr_storage address; - int err, fput_needed; - - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (!sock) - goto out; + int err; - err = security_socket_getsockname(sock); + if (peer) + err = security_socket_getpeername(sock); + else + err = security_socket_getsockname(sock); if (err) - goto out_put; - - err = sock->ops->getname(sock, (struct sockaddr *)&address, 0); + return err; + err = READ_ONCE(sock->ops)->getname(sock, (struct sockaddr *)&address, peer); if (err < 0) - goto out_put; + return err; /* "err" is actually length in this case */ - err = move_addr_to_user(&address, err, usockaddr, usockaddr_len); - -out_put: - fput_light(sock->file, fput_needed); -out: - return err; -} - -SYSCALL_DEFINE3(getsockname, int, fd, struct sockaddr __user *, usockaddr, - int __user *, usockaddr_len) -{ - return __sys_getsockname(fd, usockaddr, usockaddr_len); + return move_addr_to_user(&address, err, usockaddr, usockaddr_len); } /* - * Get the remote address ('name') of a socket object. Move the obtained - * name to user space. + * Get the remote or local address ('name') of a socket object. Move the + * obtained name to user space. */ - -int __sys_getpeername(int fd, struct sockaddr __user *usockaddr, - int __user *usockaddr_len) +int __sys_getsockname(int fd, struct sockaddr __user *usockaddr, + int __user *usockaddr_len, int peer) { struct socket *sock; - struct sockaddr_storage address; - int err, fput_needed; + CLASS(fd, f)(fd); - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (sock != NULL) { - err = security_socket_getpeername(sock); - if (err) { - fput_light(sock->file, fput_needed); - return err; - } + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; + return do_getsockname(sock, peer, usockaddr, usockaddr_len); +} - err = sock->ops->getname(sock, (struct sockaddr *)&address, 1); - if (err >= 0) - /* "err" is actually length in this case */ - err = move_addr_to_user(&address, err, usockaddr, - usockaddr_len); - fput_light(sock->file, fput_needed); - } - return err; +SYSCALL_DEFINE3(getsockname, int, fd, struct sockaddr __user *, usockaddr, + int __user *, usockaddr_len) +{ + return __sys_getsockname(fd, usockaddr, usockaddr_len, 0); } SYSCALL_DEFINE3(getpeername, int, fd, struct sockaddr __user *, usockaddr, int __user *, usockaddr_len) { - return __sys_getpeername(fd, usockaddr, usockaddr_len); + return __sys_getsockname(fd, usockaddr, usockaddr_len, 1); } /* @@ -2013,36 +2175,35 @@ int __sys_sendto(int fd, void __user *buff, size_t len, unsigned int flags, struct sockaddr_storage address; int err; struct msghdr msg; - struct iovec iov; - int fput_needed; - err = import_single_range(WRITE, buff, len, &iov, &msg.msg_iter); + err = import_ubuf(ITER_SOURCE, buff, len, &msg.msg_iter); if (unlikely(err)) return err; - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (!sock) - goto out; + + CLASS(fd, f)(fd); + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; msg.msg_name = NULL; msg.msg_control = NULL; msg.msg_controllen = 0; msg.msg_namelen = 0; + msg.msg_ubuf = NULL; if (addr) { err = move_addr_to_kernel(addr, addr_len, &address); if (err < 0) - goto out_put; + return err; msg.msg_name = (struct sockaddr *)&address; msg.msg_namelen = addr_len; } + flags &= ~MSG_INTERNAL_SENDMSG_FLAGS; if (sock->file->f_flags & O_NONBLOCK) flags |= MSG_DONTWAIT; msg.msg_flags = flags; - err = sock_sendmsg(sock, &msg); - -out_put: - fput_light(sock->file, fput_needed); -out: - return err; + return __sock_sendmsg(sock, &msg); } SYSCALL_DEFINE6(sendto, int, fd, void __user *, buff, size_t, len, @@ -2070,28 +2231,26 @@ SYSCALL_DEFINE4(send, int, fd, void __user *, buff, size_t, len, int __sys_recvfrom(int fd, void __user *ubuf, size_t size, unsigned int flags, struct sockaddr __user *addr, int __user *addr_len) { - struct socket *sock; - struct iovec iov; - struct msghdr msg; struct sockaddr_storage address; + struct msghdr msg = { + /* Save some cycles and don't copy the address if not needed */ + .msg_name = addr ? (struct sockaddr *)&address : NULL, + }; + struct socket *sock; int err, err2; - int fput_needed; - err = import_single_range(READ, ubuf, size, &iov, &msg.msg_iter); + err = import_ubuf(ITER_DEST, ubuf, size, &msg.msg_iter); if (unlikely(err)) return err; - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (!sock) - goto out; - msg.msg_control = NULL; - msg.msg_controllen = 0; - /* Save some cycles and don't copy the address if not needed */ - msg.msg_name = addr ? (struct sockaddr *)&address : NULL; - /* We assume all kernel code knows the size of sockaddr_storage */ - msg.msg_namelen = 0; - msg.msg_iocb = NULL; - msg.msg_flags = 0; + CLASS(fd, f)(fd); + + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; + if (sock->file->f_flags & O_NONBLOCK) flags |= MSG_DONTWAIT; err = sock_recvmsg(sock, &msg, flags); @@ -2102,9 +2261,6 @@ int __sys_recvfrom(int fd, void __user *ubuf, size_t size, unsigned int flags, if (err2 < 0) err = err2; } - - fput_light(sock->file, fput_needed); -out: return err; } @@ -2127,41 +2283,26 @@ SYSCALL_DEFINE4(recv, int, fd, void __user *, ubuf, size_t, size, static bool sock_use_custom_sol_socket(const struct socket *sock) { - const struct sock *sk = sock->sk; - - /* Use sock->ops->setsockopt() for MPTCP */ - return IS_ENABLED(CONFIG_MPTCP) && - sk->sk_protocol == IPPROTO_MPTCP && - sk->sk_type == SOCK_STREAM && - (sk->sk_family == AF_INET || sk->sk_family == AF_INET6); + return test_bit(SOCK_CUSTOM_SOCKOPT, &sock->flags); } -/* - * Set a socket option. Because we don't know the option lengths we have - * to pass the user mode parameter for the protocols to sort out. - */ -int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval, - int optlen) +int do_sock_setsockopt(struct socket *sock, bool compat, int level, + int optname, sockptr_t optval, int optlen) { - sockptr_t optval = USER_SOCKPTR(user_optval); + const struct proto_ops *ops; char *kernel_optval = NULL; - int err, fput_needed; - struct socket *sock; + int err; if (optlen < 0) return -EINVAL; - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (!sock) - return err; - err = security_socket_setsockopt(sock, level, optname); if (err) goto out_put; - if (!in_compat_syscall()) + if (!compat) err = BPF_CGROUP_RUN_PROG_SETSOCKOPT(sock->sk, &level, &optname, - user_optval, &optlen, + optval, &optlen, &kernel_optval); if (err < 0) goto out_put; @@ -2172,18 +2313,39 @@ int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval, if (kernel_optval) optval = KERNEL_SOCKPTR(kernel_optval); + ops = READ_ONCE(sock->ops); if (level == SOL_SOCKET && !sock_use_custom_sol_socket(sock)) err = sock_setsockopt(sock, level, optname, optval, optlen); - else if (unlikely(!sock->ops->setsockopt)) + else if (unlikely(!ops->setsockopt)) err = -EOPNOTSUPP; else - err = sock->ops->setsockopt(sock, level, optname, optval, + err = ops->setsockopt(sock, level, optname, optval, optlen); kfree(kernel_optval); out_put: - fput_light(sock->file, fput_needed); return err; } +EXPORT_SYMBOL(do_sock_setsockopt); + +/* Set a socket option. Because we don't know the option lengths we have + * to pass the user mode parameter for the protocols to sort out. + */ +int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval, + int optlen) +{ + sockptr_t optval = USER_SOCKPTR(user_optval); + bool compat = in_compat_syscall(); + struct socket *sock; + CLASS(fd, f)(fd); + + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; + + return do_sock_setsockopt(sock, compat, level, optname, optval, optlen); +} SYSCALL_DEFINE5(setsockopt, int, fd, int, level, int, optname, char __user *, optval, int, optlen) @@ -2194,44 +2356,62 @@ SYSCALL_DEFINE5(setsockopt, int, fd, int, level, int, optname, INDIRECT_CALLABLE_DECLARE(bool tcp_bpf_bypass_getsockopt(int level, int optname)); -/* - * Get a socket option. Because we don't know the option lengths we have - * to pass a user mode parameter for the protocols to sort out. - */ -int __sys_getsockopt(int fd, int level, int optname, char __user *optval, - int __user *optlen) +int do_sock_getsockopt(struct socket *sock, bool compat, int level, + int optname, sockptr_t optval, sockptr_t optlen) { - int err, fput_needed; - struct socket *sock; - int max_optlen; - - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (!sock) - return err; + int max_optlen __maybe_unused = 0; + const struct proto_ops *ops; + int err; err = security_socket_getsockopt(sock, level, optname); if (err) - goto out_put; + return err; - if (!in_compat_syscall()) - max_optlen = BPF_CGROUP_GETSOCKOPT_MAX_OPTLEN(optlen); + if (!compat) + copy_from_sockptr(&max_optlen, optlen, sizeof(int)); - if (level == SOL_SOCKET) - err = sock_getsockopt(sock, level, optname, optval, optlen); - else if (unlikely(!sock->ops->getsockopt)) + ops = READ_ONCE(sock->ops); + if (level == SOL_SOCKET) { + err = sk_getsockopt(sock->sk, level, optname, optval, optlen); + } else if (unlikely(!ops->getsockopt)) { err = -EOPNOTSUPP; - else - err = sock->ops->getsockopt(sock, level, optname, optval, - optlen); + } else { + if (WARN_ONCE(optval.is_kernel || optlen.is_kernel, + "Invalid argument type")) + return -EOPNOTSUPP; + + err = ops->getsockopt(sock, level, optname, optval.user, + optlen.user); + } - if (!in_compat_syscall()) + if (!compat) err = BPF_CGROUP_RUN_PROG_GETSOCKOPT(sock->sk, level, optname, optval, optlen, max_optlen, err); -out_put: - fput_light(sock->file, fput_needed); + return err; } +EXPORT_SYMBOL(do_sock_getsockopt); + +/* + * Get a socket option. Because we don't know the option lengths we have + * to pass a user mode parameter for the protocols to sort out. + */ +int __sys_getsockopt(int fd, int level, int optname, char __user *optval, + int __user *optlen) +{ + struct socket *sock; + CLASS(fd, f)(fd); + + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; + + return do_sock_getsockopt(sock, in_compat_syscall(), level, optname, + USER_SOCKPTR(optval), USER_SOCKPTR(optlen)); +} SYSCALL_DEFINE5(getsockopt, int, fd, int, level, int, optname, char __user *, optval, int __user *, optlen) @@ -2249,22 +2429,23 @@ int __sys_shutdown_sock(struct socket *sock, int how) err = security_socket_shutdown(sock, how); if (!err) - err = sock->ops->shutdown(sock, how); + err = READ_ONCE(sock->ops)->shutdown(sock, how); return err; } int __sys_shutdown(int fd, int how) { - int err, fput_needed; struct socket *sock; + CLASS(fd, f)(fd); - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (sock != NULL) { - err = __sys_shutdown_sock(sock, how); - fput_light(sock->file, fput_needed); - } - return err; + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; + + return __sys_shutdown_sock(sock, how); } SYSCALL_DEFINE2(shutdown, int, fd, int, how) @@ -2284,24 +2465,20 @@ struct used_address { unsigned int name_len; }; -int __copy_msghdr_from_user(struct msghdr *kmsg, - struct user_msghdr __user *umsg, - struct sockaddr __user **save_addr, - struct iovec __user **uiov, size_t *nsegs) +int __copy_msghdr(struct msghdr *kmsg, + struct user_msghdr *msg, + struct sockaddr __user **save_addr) { - struct user_msghdr msg; ssize_t err; - if (copy_from_user(&msg, umsg, sizeof(*umsg))) - return -EFAULT; - kmsg->msg_control_is_user = true; - kmsg->msg_control_user = msg.msg_control; - kmsg->msg_controllen = msg.msg_controllen; - kmsg->msg_flags = msg.msg_flags; + kmsg->msg_get_inq = 0; + kmsg->msg_control_user = msg->msg_control; + kmsg->msg_controllen = msg->msg_controllen; + kmsg->msg_flags = msg->msg_flags; - kmsg->msg_namelen = msg.msg_namelen; - if (!msg.msg_name) + kmsg->msg_namelen = msg->msg_namelen; + if (!msg->msg_name) kmsg->msg_namelen = 0; if (kmsg->msg_namelen < 0) @@ -2311,11 +2488,11 @@ int __copy_msghdr_from_user(struct msghdr *kmsg, kmsg->msg_namelen = sizeof(struct sockaddr_storage); if (save_addr) - *save_addr = msg.msg_name; + *save_addr = msg->msg_name; - if (msg.msg_name && kmsg->msg_namelen) { + if (msg->msg_name && kmsg->msg_namelen) { if (!save_addr) { - err = move_addr_to_kernel(msg.msg_name, + err = move_addr_to_kernel(msg->msg_name, kmsg->msg_namelen, kmsg->msg_name); if (err < 0) @@ -2326,12 +2503,11 @@ int __copy_msghdr_from_user(struct msghdr *kmsg, kmsg->msg_namelen = 0; } - if (msg.msg_iovlen > UIO_MAXIOV) + if (msg->msg_iovlen > UIO_MAXIOV) return -EMSGSIZE; kmsg->msg_iocb = NULL; - *uiov = msg.msg_iov; - *nsegs = msg.msg_iovlen; + kmsg->msg_ubuf = NULL; return 0; } @@ -2343,12 +2519,14 @@ static int copy_msghdr_from_user(struct msghdr *kmsg, struct user_msghdr msg; ssize_t err; - err = __copy_msghdr_from_user(kmsg, umsg, save_addr, &msg.msg_iov, - &msg.msg_iovlen); + if (copy_from_user(&msg, umsg, sizeof(*umsg))) + return -EFAULT; + + err = __copy_msghdr(kmsg, &msg, save_addr); if (err) return err; - err = import_iovec(save_addr ? READ : WRITE, + err = import_iovec(save_addr ? ITER_DEST : ITER_SOURCE, msg.msg_iov, msg.msg_iovlen, UIO_FASTIOV, iov, &kmsg->msg_iter); return err < 0 ? err : 0; @@ -2393,6 +2571,7 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys, msg_sys->msg_control = ctl_buf; msg_sys->msg_control_is_user = false; } + flags &= ~MSG_INTERNAL_SENDMSG_FLAGS; msg_sys->msg_flags = flags; if (sock->file->f_flags & O_NONBLOCK) @@ -2410,7 +2589,7 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys, err = sock_sendmsg_nosec(sock, msg_sys); goto out_freectl; } - err = sock_sendmsg(sock, msg_sys); + err = __sock_sendmsg(sock, msg_sys); /* * If this is sendmmsg() and sending to current destination address was * successful, remember it. @@ -2429,9 +2608,9 @@ out: return err; } -int sendmsg_copy_msghdr(struct msghdr *msg, - struct user_msghdr __user *umsg, unsigned flags, - struct iovec **iov) +static int sendmsg_copy_msghdr(struct msghdr *msg, + struct user_msghdr __user *umsg, unsigned flags, + struct iovec **iov) { int err; @@ -2482,22 +2661,21 @@ long __sys_sendmsg_sock(struct socket *sock, struct msghdr *msg, long __sys_sendmsg(int fd, struct user_msghdr __user *msg, unsigned int flags, bool forbid_cmsg_compat) { - int fput_needed, err; struct msghdr msg_sys; struct socket *sock; if (forbid_cmsg_compat && (flags & MSG_CMSG_COMPAT)) return -EINVAL; - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (!sock) - goto out; + CLASS(fd, f)(fd); - err = ___sys_sendmsg(sock, msg, &msg_sys, flags, NULL, 0); + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; - fput_light(sock->file, fput_needed); -out: - return err; + return ___sys_sendmsg(sock, msg, &msg_sys, flags, NULL, 0); } SYSCALL_DEFINE3(sendmsg, int, fd, struct user_msghdr __user *, msg, unsigned int, flags) @@ -2512,7 +2690,7 @@ SYSCALL_DEFINE3(sendmsg, int, fd, struct user_msghdr __user *, msg, unsigned int int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen, unsigned int flags, bool forbid_cmsg_compat) { - int fput_needed, err, datagrams; + int err, datagrams; struct socket *sock; struct mmsghdr __user *entry; struct compat_mmsghdr __user *compat_entry; @@ -2528,9 +2706,13 @@ int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen, datagrams = 0; - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (!sock) - return err; + CLASS(fd, f)(fd); + + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; used_address.name_len = UINT_MAX; entry = mmsg; @@ -2567,8 +2749,6 @@ int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen, cond_resched(); } - fput_light(sock->file, fput_needed); - /* We only return an error if no datagrams were able to be sent */ if (datagrams != 0) return datagrams; @@ -2582,10 +2762,10 @@ SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg, return __sys_sendmmsg(fd, mmsg, vlen, flags, true); } -int recvmsg_copy_msghdr(struct msghdr *msg, - struct user_msghdr __user *umsg, unsigned flags, - struct sockaddr __user **uaddr, - struct iovec **iov) +static int recvmsg_copy_msghdr(struct msghdr *msg, + struct user_msghdr __user *umsg, unsigned flags, + struct sockaddr __user **uaddr, + struct iovec **iov) { ssize_t err; @@ -2690,22 +2870,21 @@ long __sys_recvmsg_sock(struct socket *sock, struct msghdr *msg, long __sys_recvmsg(int fd, struct user_msghdr __user *msg, unsigned int flags, bool forbid_cmsg_compat) { - int fput_needed, err; struct msghdr msg_sys; struct socket *sock; if (forbid_cmsg_compat && (flags & MSG_CMSG_COMPAT)) return -EINVAL; - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (!sock) - goto out; + CLASS(fd, f)(fd); - err = ___sys_recvmsg(sock, msg, &msg_sys, flags, 0); + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; - fput_light(sock->file, fput_needed); -out: - return err; + return ___sys_recvmsg(sock, msg, &msg_sys, flags, 0); } SYSCALL_DEFINE3(recvmsg, int, fd, struct user_msghdr __user *, msg, @@ -2722,7 +2901,7 @@ static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen, unsigned int flags, struct timespec64 *timeout) { - int fput_needed, err, datagrams; + int err = 0, datagrams; struct socket *sock; struct mmsghdr __user *entry; struct compat_mmsghdr __user *compat_entry; @@ -2737,16 +2916,18 @@ static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg, datagrams = 0; - sock = sockfd_lookup_light(fd, &err, &fput_needed); - if (!sock) - return err; + CLASS(fd, f)(fd); + + if (fd_empty(f)) + return -EBADF; + sock = sock_from_file(fd_file(f)); + if (unlikely(!sock)) + return -ENOTSOCK; if (likely(!(flags & MSG_ERRQUEUE))) { err = sock_error(sock->sk); - if (err) { - datagrams = err; - goto out_put; - } + if (err) + return err; } entry = mmsg; @@ -2803,12 +2984,10 @@ static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg, } if (err == 0) - goto out_put; + return datagrams; - if (datagrams == 0) { - datagrams = err; - goto out_put; - } + if (datagrams == 0) + return err; /* * We may return less entries than requested (vlen) if the @@ -2821,11 +3000,8 @@ static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg, * error to return on the next call or if the * app asks about it using getsockopt(SO_ERROR). */ - sock->sk->sk_err = -err; + WRITE_ONCE(sock->sk->sk_err, -err); } -out_put: - fput_light(sock->file, fput_needed); - return datagrams; } @@ -2948,12 +3124,12 @@ SYSCALL_DEFINE2(socketcall, int, call, unsigned long __user *, args) case SYS_GETSOCKNAME: err = __sys_getsockname(a0, (struct sockaddr __user *)a1, - (int __user *)a[2]); + (int __user *)a[2], 0); break; case SYS_GETPEERNAME: err = - __sys_getpeername(a0, (struct sockaddr __user *)a1, - (int __user *)a[2]); + __sys_getsockname(a0, (struct sockaddr __user *)a1, + (int __user *)a[2], 1); break; case SYS_SOCKETPAIR: err = __sys_socketpair(a0, a1, a[2], (int __user *)a[3]); @@ -3243,6 +3419,7 @@ static int compat_sock_ioctl_trans(struct file *file, struct socket *sock, void __user *argp = compat_ptr(arg); struct sock *sk = sock->sk; struct net *net = sock_net(sk); + const struct proto_ops *ops; if (cmd >= SIOCDEVPRIVATE && cmd <= (SIOCDEVPRIVATE + 15)) return sock_ioctl(file, cmd, (unsigned long)argp); @@ -3252,10 +3429,11 @@ static int compat_sock_ioctl_trans(struct file *file, struct socket *sock, return compat_siocwandev(net, argp); case SIOCGSTAMP_OLD: case SIOCGSTAMPNS_OLD: - if (!sock->ops->gettstamp) + ops = READ_ONCE(sock->ops); + if (!ops->gettstamp) return -ENOIOCTLCMD; - return sock->ops->gettstamp(sock, argp, cmd == SIOCGSTAMP_OLD, - !COMPAT_USE_64BIT_TIME); + return ops->gettstamp(sock, argp, cmd == SIOCGSTAMP_OLD, + !COMPAT_USE_64BIT_TIME); case SIOCETHTOOL: case SIOCBONDSLAVEINFOQUERY: @@ -3270,6 +3448,8 @@ static int compat_sock_ioctl_trans(struct file *file, struct socket *sock, case SIOCGPGRP: case SIOCBRADDBR: case SIOCBRDELBR: + case SIOCBRADDIF: + case SIOCBRDELIF: case SIOCGIFVLAN: case SIOCSIFVLAN: case SIOCGSKNS: @@ -3309,8 +3489,6 @@ static int compat_sock_ioctl_trans(struct file *file, struct socket *sock, case SIOCGIFPFLAGS: case SIOCGIFTXQLEN: case SIOCSIFTXQLEN: - case SIOCBRADDIF: - case SIOCBRDELIF: case SIOCGIFNAME: case SIOCSIFNAME: case SIOCGMIIPHY: @@ -3336,6 +3514,7 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd, unsigned long arg) { struct socket *sock = file->private_data; + const struct proto_ops *ops = READ_ONCE(sock->ops); int ret = -ENOIOCTLCMD; struct sock *sk; struct net *net; @@ -3343,8 +3522,8 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd, sk = sock->sk; net = sock_net(sk); - if (sock->ops->compat_ioctl) - ret = sock->ops->compat_ioctl(sock, cmd, arg); + if (ops->compat_ioctl) + ret = ops->compat_ioctl(sock, cmd, arg); if (ret == -ENOIOCTLCMD && (cmd >= SIOCIWFIRST && cmd <= SIOCIWLAST)) @@ -3366,9 +3545,14 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd, * Returns 0 or an error. */ -int kernel_bind(struct socket *sock, struct sockaddr *addr, int addrlen) +int kernel_bind(struct socket *sock, struct sockaddr_unsized *addr, int addrlen) { - return sock->ops->bind(sock, addr, addrlen); + struct sockaddr_storage address; + + memcpy(&address, addr, addrlen); + + return READ_ONCE(sock->ops)->bind(sock, (struct sockaddr_unsized *)&address, + addrlen); } EXPORT_SYMBOL(kernel_bind); @@ -3382,7 +3566,7 @@ EXPORT_SYMBOL(kernel_bind); int kernel_listen(struct socket *sock, int backlog) { - return sock->ops->listen(sock, backlog); + return READ_ONCE(sock->ops)->listen(sock, backlog); } EXPORT_SYMBOL(kernel_listen); @@ -3400,6 +3584,11 @@ EXPORT_SYMBOL(kernel_listen); int kernel_accept(struct socket *sock, struct socket **newsock, int flags) { struct sock *sk = sock->sk; + const struct proto_ops *ops = READ_ONCE(sock->ops); + struct proto_accept_arg arg = { + .flags = flags, + .kern = true, + }; int err; err = sock_create_lite(sk->sk_family, sk->sk_type, sk->sk_protocol, @@ -3407,15 +3596,15 @@ int kernel_accept(struct socket *sock, struct socket **newsock, int flags) if (err < 0) goto done; - err = sock->ops->accept(sock, *newsock, flags, true); + err = ops->accept(sock, *newsock, &arg); if (err < 0) { sock_release(*newsock); *newsock = NULL; goto done; } - (*newsock)->ops = sock->ops; - __module_get((*newsock)->ops->owner); + (*newsock)->ops = ops; + __module_get(ops->owner); done: return err; @@ -3435,10 +3624,15 @@ EXPORT_SYMBOL(kernel_accept); * Returns 0 or an error code. */ -int kernel_connect(struct socket *sock, struct sockaddr *addr, int addrlen, +int kernel_connect(struct socket *sock, struct sockaddr_unsized *addr, int addrlen, int flags) { - return sock->ops->connect(sock, addr, addrlen, flags); + struct sockaddr_storage address; + + memcpy(&address, addr, addrlen); + + return READ_ONCE(sock->ops)->connect(sock, (struct sockaddr_unsized *)&address, + addrlen, flags); } EXPORT_SYMBOL(kernel_connect); @@ -3448,12 +3642,12 @@ EXPORT_SYMBOL(kernel_connect); * @addr: address holder * * Fills the @addr pointer with the address which the socket is bound. - * Returns 0 or an error code. + * Returns the length of the address in bytes or an error code. */ int kernel_getsockname(struct socket *sock, struct sockaddr *addr) { - return sock->ops->getname(sock, addr, 0); + return READ_ONCE(sock->ops)->getname(sock, addr, 0); } EXPORT_SYMBOL(kernel_getsockname); @@ -3463,64 +3657,16 @@ EXPORT_SYMBOL(kernel_getsockname); * @addr: address holder * * Fills the @addr pointer with the address which the socket is connected. - * Returns 0 or an error code. + * Returns the length of the address in bytes or an error code. */ int kernel_getpeername(struct socket *sock, struct sockaddr *addr) { - return sock->ops->getname(sock, addr, 1); + return READ_ONCE(sock->ops)->getname(sock, addr, 1); } EXPORT_SYMBOL(kernel_getpeername); /** - * kernel_sendpage - send a &page through a socket (kernel space) - * @sock: socket - * @page: page - * @offset: page offset - * @size: total size in bytes - * @flags: flags (MSG_DONTWAIT, ...) - * - * Returns the total amount sent in bytes or an error. - */ - -int kernel_sendpage(struct socket *sock, struct page *page, int offset, - size_t size, int flags) -{ - if (sock->ops->sendpage) { - /* Warn in case the improper page to zero-copy send */ - WARN_ONCE(!sendpage_ok(page), "improper page for zero-copy send"); - return sock->ops->sendpage(sock, page, offset, size, flags); - } - return sock_no_sendpage(sock, page, offset, size, flags); -} -EXPORT_SYMBOL(kernel_sendpage); - -/** - * kernel_sendpage_locked - send a &page through the locked sock (kernel space) - * @sk: sock - * @page: page - * @offset: page offset - * @size: total size in bytes - * @flags: flags (MSG_DONTWAIT, ...) - * - * Returns the total amount sent in bytes or an error. - * Caller must hold @sk. - */ - -int kernel_sendpage_locked(struct sock *sk, struct page *page, int offset, - size_t size, int flags) -{ - struct socket *sock = sk->sk_socket; - - if (sock->ops->sendpage_locked) - return sock->ops->sendpage_locked(sk, page, offset, size, - flags); - - return sock_no_sendpage_locked(sk, page, offset, size, flags); -} -EXPORT_SYMBOL(kernel_sendpage_locked); - -/** * kernel_sock_shutdown - shut down part of a full-duplex connection (kernel space) * @sock: socket * @how: connection part @@ -3530,7 +3676,7 @@ EXPORT_SYMBOL(kernel_sendpage_locked); int kernel_sock_shutdown(struct socket *sock, enum sock_shutdown_cmd how) { - return sock->ops->shutdown(sock, how); + return READ_ONCE(sock->ops)->shutdown(sock, how); } EXPORT_SYMBOL(kernel_sock_shutdown); |
