proto_ops: Add locked held versions of sendmsg and sendpage
authorTom Herbert <tom@quantonium.net>
Fri, 28 Jul 2017 23:22:41 +0000 (16:22 -0700)
committerDavid S. Miller <davem@davemloft.net>
Tue, 1 Aug 2017 22:26:18 +0000 (15:26 -0700)
Add new proto_ops sendmsg_locked and sendpage_locked that can be
called when the socket lock is already held. Correspondingly, add
kernel_sendmsg_locked and kernel_sendpage_locked as front end
functions.

These functions will be used in zero proxy so that we can take
the socket lock in a ULP sendmsg/sendpage and then directly call the
backend transport proto_ops functions.

Signed-off-by: Tom Herbert <tom@quantonium.net>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/linux/net.h
include/net/sock.h
include/net/tcp.h
net/core/sock.c
net/ipv4/af_inet.c
net/ipv4/tcp.c
net/socket.c

index dda2cc939a531dab67441c6ddf4b7869c6a06159..b5c15b31709b794a924587da591977139f0e56fe 100644 (file)
@@ -190,8 +190,16 @@ struct proto_ops {
                                       struct pipe_inode_info *pipe, size_t len, unsigned int flags);
        int             (*set_peek_off)(struct sock *sk, int val);
        int             (*peek_len)(struct socket *sock);
+
+       /* The following functions are called internally by kernel with
+        * sock lock already held.
+        */
        int             (*read_sock)(struct sock *sk, read_descriptor_t *desc,
                                     sk_read_actor_t recv_actor);
+       int             (*sendpage_locked)(struct sock *sk, struct page *page,
+                                          int offset, size_t size, int flags);
+       int             (*sendmsg_locked)(struct sock *sk, struct msghdr *msg,
+                                         size_t size);
 };
 
 #define DECLARE_SOCKADDR(type, dst, src)       \
@@ -279,6 +287,8 @@ do {                                                                        \
 
 int kernel_sendmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec,
                   size_t num, size_t len);
+int kernel_sendmsg_locked(struct sock *sk, struct msghdr *msg,
+                         struct kvec *vec, size_t num, size_t len);
 int kernel_recvmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec,
                   size_t num, size_t len, int flags);
 
@@ -297,6 +307,8 @@ int kernel_setsockopt(struct socket *sock, int level, int optname, char *optval,
                      unsigned int optlen);
 int kernel_sendpage(struct socket *sock, struct page *page, int offset,
                    size_t size, int flags);
+int kernel_sendpage_locked(struct sock *sk, struct page *page, int offset,
+                          size_t size, int flags);
 int kernel_sock_ioctl(struct socket *sock, int cmd, unsigned long arg);
 int kernel_sock_shutdown(struct socket *sock, enum sock_shutdown_cmd how);
 
index 7c0632c7e87043ca18fe5d32d7d55792f75ca6e8..393c38e9f6aa799a27681f2405750d43081788f9 100644 (file)
@@ -1582,11 +1582,14 @@ int sock_no_shutdown(struct socket *, int);
 int sock_no_getsockopt(struct socket *, int , int, char __user *, int __user *);
 int sock_no_setsockopt(struct socket *, int, int, char __user *, unsigned int);
 int sock_no_sendmsg(struct socket *, struct msghdr *, size_t);
+int sock_no_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t len);
 int sock_no_recvmsg(struct socket *, struct msghdr *, size_t, int);
 int sock_no_mmap(struct file *file, struct socket *sock,
                 struct vm_area_struct *vma);
 ssize_t sock_no_sendpage(struct socket *sock, struct page *page, int offset,
                         size_t size, int flags);
+ssize_t sock_no_sendpage_locked(struct sock *sk, struct page *page,
+                               int offset, size_t size, int flags);
 
 /*
  * Functions to fill in entries in struct proto_ops when a protocol
index 3ecb628110042de7a533335361a67b2c601eeb87..bb1881b4ce486b38ade45e5a7d9a2ba03bb7c279 100644 (file)
@@ -350,8 +350,11 @@ int tcp_v4_rcv(struct sk_buff *skb);
 
 int tcp_v4_tw_remember_stamp(struct inet_timewait_sock *tw);
 int tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size);
 int tcp_sendpage(struct sock *sk, struct page *page, int offset, size_t size,
                 int flags);
+int tcp_sendpage_locked(struct sock *sk, struct page *page, int offset,
+                       size_t size, int flags);
 ssize_t do_tcp_sendpages(struct sock *sk, struct page *page, int offset,
                 size_t size, int flags);
 void tcp_release_cb(struct sock *sk);
index ac2a404c73eb83fc100ed0a2d3bee40ad5a2044c..742f68c9c84a2b73f2df6069ad24469f0c233088 100644 (file)
@@ -2500,6 +2500,12 @@ int sock_no_sendmsg(struct socket *sock, struct msghdr *m, size_t len)
 }
 EXPORT_SYMBOL(sock_no_sendmsg);
 
+int sock_no_sendmsg_locked(struct sock *sk, struct msghdr *m, size_t len)
+{
+       return -EOPNOTSUPP;
+}
+EXPORT_SYMBOL(sock_no_sendmsg_locked);
+
 int sock_no_recvmsg(struct socket *sock, struct msghdr *m, size_t len,
                    int flags)
 {
@@ -2528,6 +2534,22 @@ ssize_t sock_no_sendpage(struct socket *sock, struct page *page, int offset, siz
 }
 EXPORT_SYMBOL(sock_no_sendpage);
 
+ssize_t sock_no_sendpage_locked(struct sock *sk, struct page *page,
+                               int offset, size_t size, int flags)
+{
+       ssize_t res;
+       struct msghdr msg = {.msg_flags = flags};
+       struct kvec iov;
+       char *kaddr = kmap(page);
+
+       iov.iov_base = kaddr + offset;
+       iov.iov_len = size;
+       res = kernel_sendmsg_locked(sk, &msg, &iov, 1, size);
+       kunmap(page);
+       return res;
+}
+EXPORT_SYMBOL(sock_no_sendpage_locked);
+
 /*
  *     Default Socket Callbacks
  */
index 5ce44fb7d49885ffe1496248b4cdfa2eae531c1f..f0103ffe1cdbd5c0b9fb5019546e44be9e3de5b4 100644 (file)
@@ -944,6 +944,8 @@ const struct proto_ops inet_stream_ops = {
        .sendpage          = inet_sendpage,
        .splice_read       = tcp_splice_read,
        .read_sock         = tcp_read_sock,
+       .sendmsg_locked    = tcp_sendmsg_locked,
+       .sendpage_locked   = tcp_sendpage_locked,
        .peek_len          = tcp_peek_len,
 #ifdef CONFIG_COMPAT
        .compat_setsockopt = compat_sock_common_setsockopt,
index 5326b50a345060b7aabcc1fdf2f01329f1515ed2..9dd6f4dba9b136153054e2c5a3db59c695fc9b0a 100644 (file)
@@ -1046,23 +1046,29 @@ out_err:
 }
 EXPORT_SYMBOL_GPL(do_tcp_sendpages);
 
-int tcp_sendpage(struct sock *sk, struct page *page, int offset,
-                size_t size, int flags)
+int tcp_sendpage_locked(struct sock *sk, struct page *page, int offset,
+                       size_t size, int flags)
 {
-       ssize_t res;
-
        if (!(sk->sk_route_caps & NETIF_F_SG) ||
            !sk_check_csum_caps(sk))
                return sock_no_sendpage(sk->sk_socket, page, offset, size,
                                        flags);
 
-       lock_sock(sk);
-
        tcp_rate_check_app_limited(sk);  /* is sending application-limited? */
 
-       res = do_tcp_sendpages(sk, page, offset, size, flags);
+       return do_tcp_sendpages(sk, page, offset, size, flags);
+}
+
+int tcp_sendpage(struct sock *sk, struct page *page, int offset,
+                size_t size, int flags)
+{
+       int ret;
+
+       lock_sock(sk);
+       ret = tcp_sendpage_locked(sk, page, offset, size, flags);
        release_sock(sk);
-       return res;
+
+       return ret;
 }
 EXPORT_SYMBOL(tcp_sendpage);
 
@@ -1156,7 +1162,7 @@ static int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg,
        return err;
 }
 
-int tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
 {
        struct tcp_sock *tp = tcp_sk(sk);
        struct sk_buff *skb;
@@ -1167,8 +1173,6 @@ int tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
        bool sg;
        long timeo;
 
-       lock_sock(sk);
-
        flags = msg->msg_flags;
        if (unlikely(flags & MSG_FASTOPEN || inet_sk(sk)->defer_connect)) {
                err = tcp_sendmsg_fastopen(sk, msg, &copied_syn, size);
@@ -1377,7 +1381,6 @@ out:
                tcp_push(sk, flags, mss_now, tp->nonagle, size_goal);
        }
 out_nopush:
-       release_sock(sk);
        return copied + copied_syn;
 
 do_fault:
@@ -1401,9 +1404,19 @@ out_err:
                sk->sk_write_space(sk);
                tcp_chrono_stop(sk, TCP_CHRONO_SNDBUF_LIMITED);
        }
-       release_sock(sk);
        return err;
 }
+
+int tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+       int ret;
+
+       lock_sock(sk);
+       ret = tcp_sendmsg_locked(sk, msg, size);
+       release_sock(sk);
+
+       return ret;
+}
 EXPORT_SYMBOL(tcp_sendmsg);
 
 /*
index cb0fdf799f40d8f64660ce37ea6dc95ffdef3479..b332d1e8e4e4f6e9ec3517062d0ed3879f6e29d0 100644 (file)
@@ -652,6 +652,20 @@ int kernel_sendmsg(struct socket *sock, struct msghdr *msg,
 }
 EXPORT_SYMBOL(kernel_sendmsg);
 
+int kernel_sendmsg_locked(struct sock *sk, struct msghdr *msg,
+                         struct kvec *vec, size_t num, size_t size)
+{
+       struct socket *sock = sk->sk_socket;
+
+       if (!sock->ops->sendmsg_locked)
+               sock_no_sendmsg_locked(sk, msg, size);
+
+       iov_iter_kvec(&msg->msg_iter, WRITE | ITER_KVEC, vec, num, size);
+
+       return sock->ops->sendmsg_locked(sk, msg, msg_data_left(msg));
+}
+EXPORT_SYMBOL(kernel_sendmsg_locked);
+
 static bool skb_is_err_queue(const struct sk_buff *skb)
 {
        /* pkt_type of skbs enqueued on the error queue are set to
@@ -3376,6 +3390,19 @@ int kernel_sendpage(struct socket *sock, struct page *page, int offset,
 }
 EXPORT_SYMBOL(kernel_sendpage);
 
+int kernel_sendpage_locked(struct sock *sk, struct page *page, int offset,
+                          size_t size, int flags)
+{
+       struct socket *sock = sk->sk_socket;
+
+       if (sock->ops->sendpage_locked)
+               return sock->ops->sendpage_locked(sk, page, offset, size,
+                                                 flags);
+
+       return sock_no_sendpage_locked(sk, page, offset, size, flags);
+}
+EXPORT_SYMBOL(kernel_sendpage_locked);
+
 int kernel_sock_ioctl(struct socket *sock, int cmd, unsigned long arg)
 {
        mm_segment_t oldfs = get_fs();