AF_VSOCK: Shrink the area influenced by prepare_to_wait
authorClaudio Imbrenda <imbrenda@linux.vnet.ibm.com>
Tue, 22 Mar 2016 16:05:52 +0000 (17:05 +0100)
committerGreg Kroah-Hartman <gregkh@linuxfoundation.org>
Thu, 30 Nov 2017 08:37:19 +0000 (08:37 +0000)
commit f7f9b5e7f8eccfd68ffa7b8d74b07c478bb9e7f0 upstream.

When a thread is prepared for waiting by calling prepare_to_wait, sleeping
is not allowed until either the wait has taken place or finish_wait has
been called.  The existing code in af_vsock imposed unnecessary no-sleep
assumptions to a broad list of backend functions.
This patch shrinks the influence of prepare_to_wait to the area where it
is strictly needed, therefore relaxing the no-sleep restriction there.

Signed-off-by: Claudio Imbrenda <imbrenda@linux.vnet.ibm.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Cc: "Jorgen S. Hansen" <jhansen@vmware.com>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
net/vmw_vsock/af_vsock.c

index 9b5bd6d142dc4c1daf0acacb3f8abbf3b805b275..b5f1221f48d4859156aa640066e1fd80cf2927dc 100644 (file)
@@ -1209,10 +1209,14 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
 
                if (signal_pending(current)) {
                        err = sock_intr_errno(timeout);
-                       goto out_wait_error;
+                       sk->sk_state = SS_UNCONNECTED;
+                       sock->state = SS_UNCONNECTED;
+                       goto out_wait;
                } else if (timeout == 0) {
                        err = -ETIMEDOUT;
-                       goto out_wait_error;
+                       sk->sk_state = SS_UNCONNECTED;
+                       sock->state = SS_UNCONNECTED;
+                       goto out_wait;
                }
 
                prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
@@ -1220,20 +1224,17 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
 
        if (sk->sk_err) {
                err = -sk->sk_err;
-               goto out_wait_error;
-       } else
+               sk->sk_state = SS_UNCONNECTED;
+               sock->state = SS_UNCONNECTED;
+       } else {
                err = 0;
+       }
 
 out_wait:
        finish_wait(sk_sleep(sk), &wait);
 out:
        release_sock(sk);
        return err;
-
-out_wait_error:
-       sk->sk_state = SS_UNCONNECTED;
-       sock->state = SS_UNCONNECTED;
-       goto out_wait;
 }
 
 static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
@@ -1270,18 +1271,20 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
               listener->sk_err == 0) {
                release_sock(listener);
                timeout = schedule_timeout(timeout);
+               finish_wait(sk_sleep(listener), &wait);
                lock_sock(listener);
 
                if (signal_pending(current)) {
                        err = sock_intr_errno(timeout);
-                       goto out_wait;
+                       goto out;
                } else if (timeout == 0) {
                        err = -EAGAIN;
-                       goto out_wait;
+                       goto out;
                }
 
                prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
        }
+       finish_wait(sk_sleep(listener), &wait);
 
        if (listener->sk_err)
                err = -listener->sk_err;
@@ -1301,19 +1304,15 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
                 */
                if (err) {
                        vconnected->rejected = true;
-                       release_sock(connected);
-                       sock_put(connected);
-                       goto out_wait;
+               } else {
+                       newsock->state = SS_CONNECTED;
+                       sock_graft(connected, newsock);
                }
 
-               newsock->state = SS_CONNECTED;
-               sock_graft(connected, newsock);
                release_sock(connected);
                sock_put(connected);
        }
 
-out_wait:
-       finish_wait(sk_sleep(listener), &wait);
 out:
        release_sock(listener);
        return err;
@@ -1557,11 +1556,11 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
        if (err < 0)
                goto out;
 
-       prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
 
        while (total_written < len) {
                ssize_t written;
 
+               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
                while (vsock_stream_has_space(vsk) == 0 &&
                       sk->sk_err == 0 &&
                       !(sk->sk_shutdown & SEND_SHUTDOWN) &&
@@ -1570,27 +1569,33 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                        /* Don't wait for non-blocking sockets. */
                        if (timeout == 0) {
                                err = -EAGAIN;
-                               goto out_wait;
+                               finish_wait(sk_sleep(sk), &wait);
+                               goto out_err;
                        }
 
                        err = transport->notify_send_pre_block(vsk, &send_data);
-                       if (err < 0)
-                               goto out_wait;
+                       if (err < 0) {
+                               finish_wait(sk_sleep(sk), &wait);
+                               goto out_err;
+                       }
 
                        release_sock(sk);
                        timeout = schedule_timeout(timeout);
                        lock_sock(sk);
                        if (signal_pending(current)) {
                                err = sock_intr_errno(timeout);
-                               goto out_wait;
+                               finish_wait(sk_sleep(sk), &wait);
+                               goto out_err;
                        } else if (timeout == 0) {
                                err = -EAGAIN;
-                               goto out_wait;
+                               finish_wait(sk_sleep(sk), &wait);
+                               goto out_err;
                        }
 
                        prepare_to_wait(sk_sleep(sk), &wait,
                                        TASK_INTERRUPTIBLE);
                }
+               finish_wait(sk_sleep(sk), &wait);
 
                /* These checks occur both as part of and after the loop
                 * conditional since we need to check before and after
@@ -1598,16 +1603,16 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                 */
                if (sk->sk_err) {
                        err = -sk->sk_err;
-                       goto out_wait;
+                       goto out_err;
                } else if ((sk->sk_shutdown & SEND_SHUTDOWN) ||
                           (vsk->peer_shutdown & RCV_SHUTDOWN)) {
                        err = -EPIPE;
-                       goto out_wait;
+                       goto out_err;
                }
 
                err = transport->notify_send_pre_enqueue(vsk, &send_data);
                if (err < 0)
-                       goto out_wait;
+                       goto out_err;
 
                /* Note that enqueue will only write as many bytes as are free
                 * in the produce queue, so we don't need to ensure len is
@@ -1620,7 +1625,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                                len - total_written);
                if (written < 0) {
                        err = -ENOMEM;
-                       goto out_wait;
+                       goto out_err;
                }
 
                total_written += written;
@@ -1628,14 +1633,13 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                err = transport->notify_send_post_enqueue(
                                vsk, written, &send_data);
                if (err < 0)
-                       goto out_wait;
+                       goto out_err;
 
        }
 
-out_wait:
+out_err:
        if (total_written > 0)
                err = total_written;
-       finish_wait(sk_sleep(sk), &wait);
 out:
        release_sock(sk);
        return err;
@@ -1716,21 +1720,61 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
        if (err < 0)
                goto out;
 
-       prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
 
        while (1) {
-               s64 ready = vsock_stream_has_data(vsk);
+               s64 ready;
 
-               if (ready < 0) {
-                       /* Invalid queue pair content. XXX This should be
-                        * changed to a connection reset in a later change.
-                        */
+               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+               ready = vsock_stream_has_data(vsk);
 
-                       err = -ENOMEM;
-                       goto out_wait;
-               } else if (ready > 0) {
+               if (ready == 0) {
+                       if (sk->sk_err != 0 ||
+                           (sk->sk_shutdown & RCV_SHUTDOWN) ||
+                           (vsk->peer_shutdown & SEND_SHUTDOWN)) {
+                               finish_wait(sk_sleep(sk), &wait);
+                               break;
+                       }
+                       /* Don't wait for non-blocking sockets. */
+                       if (timeout == 0) {
+                               err = -EAGAIN;
+                               finish_wait(sk_sleep(sk), &wait);
+                               break;
+                       }
+
+                       err = transport->notify_recv_pre_block(
+                                       vsk, target, &recv_data);
+                       if (err < 0) {
+                               finish_wait(sk_sleep(sk), &wait);
+                               break;
+                       }
+                       release_sock(sk);
+                       timeout = schedule_timeout(timeout);
+                       lock_sock(sk);
+
+                       if (signal_pending(current)) {
+                               err = sock_intr_errno(timeout);
+                               finish_wait(sk_sleep(sk), &wait);
+                               break;
+                       } else if (timeout == 0) {
+                               err = -EAGAIN;
+                               finish_wait(sk_sleep(sk), &wait);
+                               break;
+                       }
+               } else {
                        ssize_t read;
 
+                       finish_wait(sk_sleep(sk), &wait);
+
+                       if (ready < 0) {
+                               /* Invalid queue pair content. XXX This should
+                               * be changed to a connection reset in a later
+                               * change.
+                               */
+
+                               err = -ENOMEM;
+                               goto out;
+                       }
+
                        err = transport->notify_recv_pre_dequeue(
                                        vsk, target, &recv_data);
                        if (err < 0)
@@ -1750,42 +1794,12 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
                                        vsk, target, read,
                                        !(flags & MSG_PEEK), &recv_data);
                        if (err < 0)
-                               goto out_wait;
+                               goto out;
 
                        if (read >= target || flags & MSG_PEEK)
                                break;
 
                        target -= read;
-               } else {
-                       if (sk->sk_err != 0 || (sk->sk_shutdown & RCV_SHUTDOWN)
-                           || (vsk->peer_shutdown & SEND_SHUTDOWN)) {
-                               break;
-                       }
-                       /* Don't wait for non-blocking sockets. */
-                       if (timeout == 0) {
-                               err = -EAGAIN;
-                               break;
-                       }
-
-                       err = transport->notify_recv_pre_block(
-                                       vsk, target, &recv_data);
-                       if (err < 0)
-                               break;
-
-                       release_sock(sk);
-                       timeout = schedule_timeout(timeout);
-                       lock_sock(sk);
-
-                       if (signal_pending(current)) {
-                               err = sock_intr_errno(timeout);
-                               break;
-                       } else if (timeout == 0) {
-                               err = -EAGAIN;
-                               break;
-                       }
-
-                       prepare_to_wait(sk_sleep(sk), &wait,
-                                       TASK_INTERRUPTIBLE);
                }
        }
 
@@ -1797,8 +1811,6 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
        if (copied > 0)
                err = copied;
 
-out_wait:
-       finish_wait(sk_sleep(sk), &wait);
 out:
        release_sock(sk);
        return err;