tcp: fix a timewait refcnt race
[GitHub/mt8127/android_kernel_alcatel_ttab.git] / net / ipv4 / inet_timewait_sock.c
index 11a107a5af4f64d7a0e7159079941dc80359b2d4..0fdf45e4c90c8c8475cfd28feb9830d65ebfc204 100644 (file)
@@ -109,7 +109,6 @@ void __inet_twsk_hashdance(struct inet_timewait_sock *tw, struct sock *sk,
        tw->tw_tb = icsk->icsk_bind_hash;
        WARN_ON(!icsk->icsk_bind_hash);
        inet_twsk_add_bind_node(tw, &tw->tw_tb->owners);
-       atomic_inc(&tw->tw_refcnt);
        spin_unlock(&bhead->lock);
 
        spin_lock(lock);
@@ -119,13 +118,22 @@ void __inet_twsk_hashdance(struct inet_timewait_sock *tw, struct sock *sk,
         * Should be done before removing sk from established chain
         * because readers are lockless and search established first.
         */
-       atomic_inc(&tw->tw_refcnt);
        inet_twsk_add_node_rcu(tw, &ehead->twchain);
 
        /* Step 3: Remove SK from established hash. */
        if (__sk_nulls_del_node_init_rcu(sk))
                sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
 
+       /*
+        * Notes :
+        * - We initially set tw_refcnt to 0 in inet_twsk_alloc()
+        * - We add one reference for the bhash link
+        * - We add one reference for the ehash link
+        * - We want this refcnt update done before allowing other
+        *   threads to find this tw in ehash chain.
+        */
+       atomic_add(1 + 1 + 1, &tw->tw_refcnt);
+
        spin_unlock(lock);
 }
 
@@ -157,7 +165,12 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk, const int stat
                tw->tw_transparent  = inet->transparent;
                tw->tw_prot         = sk->sk_prot_creator;
                twsk_net_set(tw, hold_net(sock_net(sk)));
-               atomic_set(&tw->tw_refcnt, 1);
+               /*
+                * Because we use RCU lookups, we should not set tw_refcnt
+                * to a non null value before everything is setup for this
+                * timewait socket.
+                */
+               atomic_set(&tw->tw_refcnt, 0);
                inet_twsk_dead_node_init(tw);
                __module_get(tw->tw_prot->owner);
        }