net: fix sleeping for sk_wait_event()
authorWANG Cong <xiyou.wangcong@gmail.com>
Fri, 11 Nov 2016 18:20:50 +0000 (10:20 -0800)
committerDavid S. Miller <davem@davemloft.net>
Mon, 14 Nov 2016 18:17:21 +0000 (13:17 -0500)
Similar to commit 14135f30e33c ("inet: fix sleeping inside inet_wait_for_connect()"),
sk_wait_event() needs to fix too, because release_sock() is blocking,
it changes the process state back to running after sleep, which breaks
the previous prepare_to_wait().

Switch to the new wait API.

Cc: Eric Dumazet <eric.dumazet@gmail.com>
Cc: Peter Zijlstra <peterz@infradead.org>
Signed-off-by: Cong Wang <xiyou.wangcong@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
crypto/algif_aead.c
crypto/algif_skcipher.c
include/net/sock.h
net/core/sock.c
net/core/stream.c
net/decnet/af_decnet.c
net/llc/af_llc.c
net/phonet/pep.c
net/tipc/socket.c
net/vmw_vsock/virtio_transport_common.c

index 80a0f1a7855181930afa0a545f5bf79584141e5c..8948392c0525db3d6d87bbc53d97f252335381e7 100644 (file)
@@ -132,28 +132,27 @@ static void aead_wmem_wakeup(struct sock *sk)
 
 static int aead_wait_for_data(struct sock *sk, unsigned flags)
 {
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        struct alg_sock *ask = alg_sk(sk);
        struct aead_ctx *ctx = ask->private;
        long timeout;
-       DEFINE_WAIT(wait);
        int err = -ERESTARTSYS;
 
        if (flags & MSG_DONTWAIT)
                return -EAGAIN;
 
        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-
+       add_wait_queue(sk_sleep(sk), &wait);
        for (;;) {
                if (signal_pending(current))
                        break;
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
                timeout = MAX_SCHEDULE_TIMEOUT;
-               if (sk_wait_event(sk, &timeout, !ctx->more)) {
+               if (sk_wait_event(sk, &timeout, !ctx->more, &wait)) {
                        err = 0;
                        break;
                }
        }
-       finish_wait(sk_sleep(sk), &wait);
+       remove_wait_queue(sk_sleep(sk), &wait);
 
        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 
index 28556fce42671e2f182d5239d3dc6468e5b1d970..1e38aaa8303ea831aef6ba0f7eedde2ac5ee4810 100644 (file)
@@ -199,26 +199,26 @@ static void skcipher_free_sgl(struct sock *sk)
 
 static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
 {
-       long timeout;
-       DEFINE_WAIT(wait);
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        int err = -ERESTARTSYS;
+       long timeout;
 
        if (flags & MSG_DONTWAIT)
                return -EAGAIN;
 
        sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);
 
+       add_wait_queue(sk_sleep(sk), &wait);
        for (;;) {
                if (signal_pending(current))
                        break;
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
                timeout = MAX_SCHEDULE_TIMEOUT;
-               if (sk_wait_event(sk, &timeout, skcipher_writable(sk))) {
+               if (sk_wait_event(sk, &timeout, skcipher_writable(sk), &wait)) {
                        err = 0;
                        break;
                }
        }
-       finish_wait(sk_sleep(sk), &wait);
+       remove_wait_queue(sk_sleep(sk), &wait);
 
        return err;
 }
@@ -242,10 +242,10 @@ static void skcipher_wmem_wakeup(struct sock *sk)
 
 static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
 {
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        struct alg_sock *ask = alg_sk(sk);
        struct skcipher_ctx *ctx = ask->private;
        long timeout;
-       DEFINE_WAIT(wait);
        int err = -ERESTARTSYS;
 
        if (flags & MSG_DONTWAIT) {
@@ -254,17 +254,17 @@ static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
 
        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 
+       add_wait_queue(sk_sleep(sk), &wait);
        for (;;) {
                if (signal_pending(current))
                        break;
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
                timeout = MAX_SCHEDULE_TIMEOUT;
-               if (sk_wait_event(sk, &timeout, ctx->used)) {
+               if (sk_wait_event(sk, &timeout, ctx->used, &wait)) {
                        err = 0;
                        break;
                }
        }
-       finish_wait(sk_sleep(sk), &wait);
+       remove_wait_queue(sk_sleep(sk), &wait);
 
        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 
index cf617ee16723f2096941548a100832154e359c36..9d905ed0cd250203f1fbe282186d7e0d4de8f518 100644 (file)
@@ -915,14 +915,16 @@ static inline void sock_rps_reset_rxhash(struct sock *sk)
 #endif
 }
 
-#define sk_wait_event(__sk, __timeo, __condition)                      \
+#define sk_wait_event(__sk, __timeo, __condition, __wait)              \
        ({      int __rc;                                               \
                release_sock(__sk);                                     \
                __rc = __condition;                                     \
                if (!__rc) {                                            \
-                       *(__timeo) = schedule_timeout(*(__timeo));      \
+                       *(__timeo) = wait_woken(__wait,                 \
+                                               TASK_INTERRUPTIBLE,     \
+                                               *(__timeo));            \
                }                                                       \
-               sched_annotate_sleep();                                         \
+               sched_annotate_sleep();                                 \
                lock_sock(__sk);                                        \
                __rc = __condition;                                     \
                __rc;                                                   \
index 40dbc13453f9a244a43cf352c8be0301b67e36f2..0397928dfdc21d1a6b77c321c24533e6c42013da 100644 (file)
@@ -2078,14 +2078,14 @@ void __sk_flush_backlog(struct sock *sk)
  */
 int sk_wait_data(struct sock *sk, long *timeo, const struct sk_buff *skb)
 {
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        int rc;
-       DEFINE_WAIT(wait);
 
-       prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+       add_wait_queue(sk_sleep(sk), &wait);
        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-       rc = sk_wait_event(sk, timeo, skb_peek_tail(&sk->sk_receive_queue) != skb);
+       rc = sk_wait_event(sk, timeo, skb_peek_tail(&sk->sk_receive_queue) != skb, &wait);
        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-       finish_wait(sk_sleep(sk), &wait);
+       remove_wait_queue(sk_sleep(sk), &wait);
        return rc;
 }
 EXPORT_SYMBOL(sk_wait_data);
index 1086c8b280a868101df7b44691eccac40bd7b3e1..f575bcf64af2c32f684f178ea553338b00a9a051 100644 (file)
@@ -53,8 +53,8 @@ void sk_stream_write_space(struct sock *sk)
  */
 int sk_stream_wait_connect(struct sock *sk, long *timeo_p)
 {
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        struct task_struct *tsk = current;
-       DEFINE_WAIT(wait);
        int done;
 
        do {
@@ -68,13 +68,13 @@ int sk_stream_wait_connect(struct sock *sk, long *timeo_p)
                if (signal_pending(tsk))
                        return sock_intr_errno(*timeo_p);
 
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+               add_wait_queue(sk_sleep(sk), &wait);
                sk->sk_write_pending++;
                done = sk_wait_event(sk, timeo_p,
                                     !sk->sk_err &&
                                     !((1 << sk->sk_state) &
-                                      ~(TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)));
-               finish_wait(sk_sleep(sk), &wait);
+                                      ~(TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)), &wait);
+               remove_wait_queue(sk_sleep(sk), &wait);
                sk->sk_write_pending--;
        } while (!done);
        return 0;
@@ -94,16 +94,16 @@ static inline int sk_stream_closing(struct sock *sk)
 void sk_stream_wait_close(struct sock *sk, long timeout)
 {
        if (timeout) {
-               DEFINE_WAIT(wait);
+               DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+               add_wait_queue(sk_sleep(sk), &wait);
 
                do {
-                       prepare_to_wait(sk_sleep(sk), &wait,
-                                       TASK_INTERRUPTIBLE);
-                       if (sk_wait_event(sk, &timeout, !sk_stream_closing(sk)))
+                       if (sk_wait_event(sk, &timeout, !sk_stream_closing(sk), &wait))
                                break;
                } while (!signal_pending(current) && timeout);
 
-               finish_wait(sk_sleep(sk), &wait);
+               remove_wait_queue(sk_sleep(sk), &wait);
        }
 }
 EXPORT_SYMBOL(sk_stream_wait_close);
@@ -119,16 +119,16 @@ int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
        long vm_wait = 0;
        long current_timeo = *timeo_p;
        bool noblock = (*timeo_p ? false : true);
-       DEFINE_WAIT(wait);
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
 
        if (sk_stream_memory_free(sk))
                current_timeo = vm_wait = (prandom_u32() % (HZ / 5)) + 2;
 
+       add_wait_queue(sk_sleep(sk), &wait);
+
        while (1) {
                sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);
 
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
-
                if (sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN))
                        goto do_error;
                if (!*timeo_p) {
@@ -147,7 +147,7 @@ int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
                sk_wait_event(sk, &current_timeo, sk->sk_err ||
                                                  (sk->sk_shutdown & SEND_SHUTDOWN) ||
                                                  (sk_stream_memory_free(sk) &&
-                                                 !vm_wait));
+                                                 !vm_wait), &wait);
                sk->sk_write_pending--;
 
                if (vm_wait) {
@@ -161,7 +161,7 @@ int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
                *timeo_p = current_timeo;
        }
 out:
-       finish_wait(sk_sleep(sk), &wait);
+       remove_wait_queue(sk_sleep(sk), &wait);
        return err;
 
 do_error:
index 13d6b1a6e0fc2b0730827d93d154d6464a3e58ec..a90ed67027b0cfa6b8ba8a25fc72b1ccd9f2886b 100644 (file)
@@ -1718,7 +1718,7 @@ static int dn_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
         * See if there is data ready to read, sleep if there isn't
         */
        for(;;) {
-               DEFINE_WAIT(wait);
+               DEFINE_WAIT_FUNC(wait, woken_wake_function);
 
                if (sk->sk_err)
                        goto out;
@@ -1749,11 +1749,11 @@ static int dn_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
                        goto out;
                }
 
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+               add_wait_queue(sk_sleep(sk), &wait);
                sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-               sk_wait_event(sk, &timeo, dn_data_ready(sk, queue, flags, target));
+               sk_wait_event(sk, &timeo, dn_data_ready(sk, queue, flags, target), &wait);
                sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-               finish_wait(sk_sleep(sk), &wait);
+               remove_wait_queue(sk_sleep(sk), &wait);
        }
 
        skb_queue_walk_safe(queue, skb, n) {
@@ -1999,19 +1999,19 @@ static int dn_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)
                 * size.
                 */
                if (dn_queue_too_long(scp, queue, flags)) {
-                       DEFINE_WAIT(wait);
+                       DEFINE_WAIT_FUNC(wait, woken_wake_function);
 
                        if (flags & MSG_DONTWAIT) {
                                err = -EWOULDBLOCK;
                                goto out;
                        }
 
-                       prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+                       add_wait_queue(sk_sleep(sk), &wait);
                        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
                        sk_wait_event(sk, &timeo,
-                                     !dn_queue_too_long(scp, queue, flags));
+                                     !dn_queue_too_long(scp, queue, flags), &wait);
                        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-                       finish_wait(sk_sleep(sk), &wait);
+                       remove_wait_queue(sk_sleep(sk), &wait);
                        continue;
                }
 
index db916cf51ffeabd3d5246cb09a9e7227d483f574..5e92963824202823bcb706c54444ad5a6e7d2358 100644 (file)
@@ -532,12 +532,12 @@ out:
 
 static int llc_ui_wait_for_disc(struct sock *sk, long timeout)
 {
-       DEFINE_WAIT(wait);
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        int rc = 0;
 
+       add_wait_queue(sk_sleep(sk), &wait);
        while (1) {
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
-               if (sk_wait_event(sk, &timeout, sk->sk_state == TCP_CLOSE))
+               if (sk_wait_event(sk, &timeout, sk->sk_state == TCP_CLOSE, &wait))
                        break;
                rc = -ERESTARTSYS;
                if (signal_pending(current))
@@ -547,39 +547,39 @@ static int llc_ui_wait_for_disc(struct sock *sk, long timeout)
                        break;
                rc = 0;
        }
-       finish_wait(sk_sleep(sk), &wait);
+       remove_wait_queue(sk_sleep(sk), &wait);
        return rc;
 }
 
 static bool llc_ui_wait_for_conn(struct sock *sk, long timeout)
 {
-       DEFINE_WAIT(wait);
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
 
+       add_wait_queue(sk_sleep(sk), &wait);
        while (1) {
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
-               if (sk_wait_event(sk, &timeout, sk->sk_state != TCP_SYN_SENT))
+               if (sk_wait_event(sk, &timeout, sk->sk_state != TCP_SYN_SENT, &wait))
                        break;
                if (signal_pending(current) || !timeout)
                        break;
        }
-       finish_wait(sk_sleep(sk), &wait);
+       remove_wait_queue(sk_sleep(sk), &wait);
        return timeout;
 }
 
 static int llc_ui_wait_for_busy_core(struct sock *sk, long timeout)
 {
-       DEFINE_WAIT(wait);
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        struct llc_sock *llc = llc_sk(sk);
        int rc;
 
+       add_wait_queue(sk_sleep(sk), &wait);
        while (1) {
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
                rc = 0;
                if (sk_wait_event(sk, &timeout,
                                  (sk->sk_shutdown & RCV_SHUTDOWN) ||
                                  (!llc_data_accept_state(llc->state) &&
                                   !llc->remote_busy_flag &&
-                                  !llc->p_flag)))
+                                  !llc->p_flag), &wait))
                        break;
                rc = -ERESTARTSYS;
                if (signal_pending(current))
@@ -588,7 +588,7 @@ static int llc_ui_wait_for_busy_core(struct sock *sk, long timeout)
                if (!timeout)
                        break;
        }
-       finish_wait(sk_sleep(sk), &wait);
+       remove_wait_queue(sk_sleep(sk), &wait);
        return rc;
 }
 
index 850a86cde0b3f6eab5b7aa09f4e6ffa66ccd6ed6..8bad5624a27a9ffdcbf193c4c2f078b4b648b044 100644 (file)
@@ -1167,7 +1167,7 @@ disabled:
        /* Wait until flow control allows TX */
        done = atomic_read(&pn->tx_credits);
        while (!done) {
-               DEFINE_WAIT(wait);
+               DEFINE_WAIT_FUNC(wait, woken_wake_function);
 
                if (!timeo) {
                        err = -EAGAIN;
@@ -1178,10 +1178,9 @@ disabled:
                        goto out;
                }
 
-               prepare_to_wait(sk_sleep(sk), &wait,
-                               TASK_INTERRUPTIBLE);
-               done = sk_wait_event(sk, &timeo, atomic_read(&pn->tx_credits));
-               finish_wait(sk_sleep(sk), &wait);
+               add_wait_queue(sk_sleep(sk), &wait);
+               done = sk_wait_event(sk, &timeo, atomic_read(&pn->tx_credits), &wait);
+               remove_wait_queue(sk_sleep(sk), &wait);
 
                if (sk->sk_state != TCP_ESTABLISHED)
                        goto disabled;
index 149396366e80402fbcfca436749a2489d3aae9a9..22d92f0ec5ac020b52d2c3ed829307e1aec8c175 100644 (file)
@@ -876,9 +876,9 @@ exit:
 
 static int tipc_wait_for_sndmsg(struct socket *sock, long *timeo_p)
 {
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        struct sock *sk = sock->sk;
        struct tipc_sock *tsk = tipc_sk(sk);
-       DEFINE_WAIT(wait);
        int done;
 
        do {
@@ -892,9 +892,9 @@ static int tipc_wait_for_sndmsg(struct socket *sock, long *timeo_p)
                if (signal_pending(current))
                        return sock_intr_errno(*timeo_p);
 
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
-               done = sk_wait_event(sk, timeo_p, !tsk->link_cong);
-               finish_wait(sk_sleep(sk), &wait);
+               add_wait_queue(sk_sleep(sk), &wait);
+               done = sk_wait_event(sk, timeo_p, !tsk->link_cong, &wait);
+               remove_wait_queue(sk_sleep(sk), &wait);
        } while (!done);
        return 0;
 }
@@ -1031,9 +1031,9 @@ new_mtu:
 
 static int tipc_wait_for_sndpkt(struct socket *sock, long *timeo_p)
 {
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        struct sock *sk = sock->sk;
        struct tipc_sock *tsk = tipc_sk(sk);
-       DEFINE_WAIT(wait);
        int done;
 
        do {
@@ -1049,12 +1049,12 @@ static int tipc_wait_for_sndpkt(struct socket *sock, long *timeo_p)
                if (signal_pending(current))
                        return sock_intr_errno(*timeo_p);
 
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+               add_wait_queue(sk_sleep(sk), &wait);
                done = sk_wait_event(sk, timeo_p,
                                     (!tsk->link_cong &&
                                      !tsk_conn_cong(tsk)) ||
-                                     !tipc_sk_connected(sk));
-               finish_wait(sk_sleep(sk), &wait);
+                                     !tipc_sk_connected(sk), &wait);
+               remove_wait_queue(sk_sleep(sk), &wait);
        } while (!done);
        return 0;
 }
@@ -1929,8 +1929,8 @@ xmit:
 
 static int tipc_wait_for_connect(struct socket *sock, long *timeo_p)
 {
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        struct sock *sk = sock->sk;
-       DEFINE_WAIT(wait);
        int done;
 
        do {
@@ -1942,10 +1942,10 @@ static int tipc_wait_for_connect(struct socket *sock, long *timeo_p)
                if (signal_pending(current))
                        return sock_intr_errno(*timeo_p);
 
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+               add_wait_queue(sk_sleep(sk), &wait);
                done = sk_wait_event(sk, timeo_p,
-                                    sk->sk_state != TIPC_CONNECTING);
-               finish_wait(sk_sleep(sk), &wait);
+                                    sk->sk_state != TIPC_CONNECTING, &wait);
+               remove_wait_queue(sk_sleep(sk), &wait);
        } while (!done);
        return 0;
 }
index a53b3a16b4f1f79554bcbc066a49d6a0e848f093..687e9fdb3d672b9e2a36c0255e5dae51ce2d03a0 100644 (file)
@@ -619,17 +619,17 @@ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
 static void virtio_transport_wait_close(struct sock *sk, long timeout)
 {
        if (timeout) {
-               DEFINE_WAIT(wait);
+               DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+               add_wait_queue(sk_sleep(sk), &wait);
 
                do {
-                       prepare_to_wait(sk_sleep(sk), &wait,
-                                       TASK_INTERRUPTIBLE);
                        if (sk_wait_event(sk, &timeout,
-                                         sock_flag(sk, SOCK_DONE)))
+                                         sock_flag(sk, SOCK_DONE), &wait))
                                break;
                } while (!signal_pending(current) && timeout);
 
-               finish_wait(sk_sleep(sk), &wait);
+               remove_wait_queue(sk_sleep(sk), &wait);
        }
 }