rtnetlink: protect handler table with rcu
authorFlorian Westphal <fw@strlen.de>
Wed, 9 Aug 2017 18:41:51 +0000 (20:41 +0200)
committerDavid S. Miller <davem@davemloft.net>
Wed, 9 Aug 2017 23:57:38 +0000 (16:57 -0700)
Note that netlink dumps still acquire rtnl mutex via the netlink
dump infrastructure.

Signed-off-by: Florian Westphal <fw@strlen.de>
Reviewed-by: Hannes Frederic Sowa <hannes@stressinduktion.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/core/rtnetlink.c

index be01d8e48661fd43fcd5812040c8e3f2781ada95..d45946177bc89e1c178fb3e9816f884c76e1a68c 100644 (file)
@@ -126,7 +126,7 @@ bool lockdep_rtnl_is_held(void)
 EXPORT_SYMBOL(lockdep_rtnl_is_held);
 #endif /* #ifdef CONFIG_PROVE_LOCKING */
 
-static struct rtnl_link *rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
+static struct rtnl_link __rcu *rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
 static refcount_t rtnl_msg_handlers_ref[RTNL_FAMILY_MAX + 1];
 
 static inline int rtm_msgindex(int msgtype)
@@ -143,36 +143,6 @@ static inline int rtm_msgindex(int msgtype)
        return msgindex;
 }
 
-static rtnl_doit_func rtnl_get_doit(int protocol, int msgindex)
-{
-       struct rtnl_link *tab;
-
-       if (protocol <= RTNL_FAMILY_MAX)
-               tab = rtnl_msg_handlers[protocol];
-       else
-               tab = NULL;
-
-       if (tab == NULL || tab[msgindex].doit == NULL)
-               tab = rtnl_msg_handlers[PF_UNSPEC];
-
-       return tab[msgindex].doit;
-}
-
-static rtnl_dumpit_func rtnl_get_dumpit(int protocol, int msgindex)
-{
-       struct rtnl_link *tab;
-
-       if (protocol <= RTNL_FAMILY_MAX)
-               tab = rtnl_msg_handlers[protocol];
-       else
-               tab = NULL;
-
-       if (tab == NULL || tab[msgindex].dumpit == NULL)
-               tab = rtnl_msg_handlers[PF_UNSPEC];
-
-       return tab[msgindex].dumpit;
-}
-
 /**
  * __rtnl_register - Register a rtnetlink message type
  * @protocol: Protocol family or PF_UNSPEC
@@ -201,18 +171,17 @@ int __rtnl_register(int protocol, int msgtype,
        BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
        msgindex = rtm_msgindex(msgtype);
 
-       tab = rtnl_msg_handlers[protocol];
+       tab = rcu_dereference(rtnl_msg_handlers[protocol]);
        if (tab == NULL) {
                tab = kcalloc(RTM_NR_MSGTYPES, sizeof(*tab), GFP_KERNEL);
                if (tab == NULL)
                        return -ENOBUFS;
 
-               rtnl_msg_handlers[protocol] = tab;
+               rcu_assign_pointer(rtnl_msg_handlers[protocol], tab);
        }
 
        if (doit)
                tab[msgindex].doit = doit;
-
        if (dumpit)
                tab[msgindex].dumpit = dumpit;
 
@@ -249,16 +218,22 @@ EXPORT_SYMBOL_GPL(rtnl_register);
  */
 int rtnl_unregister(int protocol, int msgtype)
 {
+       struct rtnl_link *handlers;
        int msgindex;
 
        BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
        msgindex = rtm_msgindex(msgtype);
 
-       if (rtnl_msg_handlers[protocol] == NULL)
+       rtnl_lock();
+       handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
+       if (!handlers) {
+               rtnl_unlock();
                return -ENOENT;
+       }
 
-       rtnl_msg_handlers[protocol][msgindex].doit = NULL;
-       rtnl_msg_handlers[protocol][msgindex].dumpit = NULL;
+       handlers[msgindex].doit = NULL;
+       handlers[msgindex].dumpit = NULL;
+       rtnl_unlock();
 
        return 0;
 }
@@ -278,10 +253,12 @@ void rtnl_unregister_all(int protocol)
        BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
 
        rtnl_lock();
-       handlers = rtnl_msg_handlers[protocol];
-       rtnl_msg_handlers[protocol] = NULL;
+       handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
+       RCU_INIT_POINTER(rtnl_msg_handlers[protocol], NULL);
        rtnl_unlock();
 
+       synchronize_net();
+
        while (refcount_read(&rtnl_msg_handlers_ref[protocol]) > 0)
                schedule();
        kfree(handlers);
@@ -2820,11 +2797,13 @@ static u16 rtnl_calcit(struct sk_buff *skb, struct nlmsghdr *nlh)
         * traverse the list of net devices and compute the minimum
         * buffer size based upon the filter mask.
         */
-       list_for_each_entry(dev, &net->dev_base_head, dev_list) {
+       rcu_read_lock();
+       for_each_netdev_rcu(net, dev) {
                min_ifinfo_dump_size = max_t(u16, min_ifinfo_dump_size,
                                             if_nlmsg_size(dev,
                                                           ext_filter_mask));
        }
+       rcu_read_unlock();
 
        return nlmsg_total_size(min_ifinfo_dump_size);
 }
@@ -2836,19 +2815,29 @@ static int rtnl_dump_all(struct sk_buff *skb, struct netlink_callback *cb)
 
        if (s_idx == 0)
                s_idx = 1;
+
        for (idx = 1; idx <= RTNL_FAMILY_MAX; idx++) {
                int type = cb->nlh->nlmsg_type-RTM_BASE;
+               struct rtnl_link *handlers;
+               rtnl_dumpit_func dumpit;
+
                if (idx < s_idx || idx == PF_PACKET)
                        continue;
-               if (rtnl_msg_handlers[idx] == NULL ||
-                   rtnl_msg_handlers[idx][type].dumpit == NULL)
+
+               handlers = rtnl_dereference(rtnl_msg_handlers[idx]);
+               if (!handlers)
                        continue;
+
+               dumpit = READ_ONCE(handlers[type].dumpit);
+               if (!dumpit)
+                       continue;
+
                if (idx > s_idx) {
                        memset(&cb->args[0], 0, sizeof(cb->args));
                        cb->prev_seq = 0;
                        cb->seq = 0;
                }
-               if (rtnl_msg_handlers[idx][type].dumpit(skb, cb))
+               if (dumpit(skb, cb))
                        break;
        }
        cb->family = idx;
@@ -4151,11 +4140,12 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
                             struct netlink_ext_ack *extack)
 {
        struct net *net = sock_net(skb->sk);
+       struct rtnl_link *handlers;
+       int err = -EOPNOTSUPP;
        rtnl_doit_func doit;
        int kind;
        int family;
        int type;
-       int err;
 
        type = nlh->nlmsg_type;
        if (type > RTM_MAX)
@@ -4173,23 +4163,40 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (kind != 2 && !netlink_net_capable(skb, CAP_NET_ADMIN))
                return -EPERM;
 
+       if (family > ARRAY_SIZE(rtnl_msg_handlers))
+               family = PF_UNSPEC;
+
+       rcu_read_lock();
+       handlers = rcu_dereference(rtnl_msg_handlers[family]);
+       if (!handlers) {
+               family = PF_UNSPEC;
+               handlers = rcu_dereference(rtnl_msg_handlers[family]);
+       }
+
        if (kind == 2 && nlh->nlmsg_flags&NLM_F_DUMP) {
                struct sock *rtnl;
                rtnl_dumpit_func dumpit;
                u16 min_dump_alloc = 0;
 
-               rtnl_lock();
+               dumpit = READ_ONCE(handlers[type].dumpit);
+               if (!dumpit) {
+                       family = PF_UNSPEC;
+                       handlers = rcu_dereference(rtnl_msg_handlers[PF_UNSPEC]);
+                       if (!handlers)
+                               goto err_unlock;
 
-               dumpit = rtnl_get_dumpit(family, type);
-               if (dumpit == NULL)
-                       goto err_unlock;
+                       dumpit = READ_ONCE(handlers[type].dumpit);
+                       if (!dumpit)
+                               goto err_unlock;
+               }
 
                refcount_inc(&rtnl_msg_handlers_ref[family]);
 
                if (type == RTM_GETLINK)
                        min_dump_alloc = rtnl_calcit(skb, nlh);
 
-               __rtnl_unlock();
+               rcu_read_unlock();
+
                rtnl = net->rtnl;
                {
                        struct netlink_dump_control c = {
@@ -4202,18 +4209,20 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
                return err;
        }
 
-       rtnl_lock();
-       doit = rtnl_get_doit(family, type);
-       if (doit == NULL)
-               goto err_unlock;
+       rcu_read_unlock();
 
-       err = doit(skb, nlh, extack);
+       rtnl_lock();
+       handlers = rtnl_dereference(rtnl_msg_handlers[family]);
+       if (handlers) {
+               doit = READ_ONCE(handlers[type].doit);
+               if (doit)
+                       err = doit(skb, nlh, extack);
+       }
        rtnl_unlock();
-
        return err;
 
 err_unlock:
-       rtnl_unlock();
+       rcu_read_unlock();
        return -EOPNOTSUPP;
 }