net: ipmr: fix static mfc/dev leaks on table destruction
[GitHub/mt8127/android_kernel_alcatel_ttab.git] / net / ipv4 / ipmr.c
index 9d9610ae78553895e9f6ccb0ea8260fa4a66ac17..b31553d385bba1faf763dfb771a6f7ba529d6d26 100644 (file)
@@ -136,7 +136,7 @@ static int __ipmr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb,
                              struct mfc_cache *c, struct rtmsg *rtm);
 static void mroute_netlink_event(struct mr_table *mrt, struct mfc_cache *mfc,
                                 int cmd);
-static void mroute_clean_tables(struct mr_table *mrt);
+static void mroute_clean_tables(struct mr_table *mrt, bool all);
 static void ipmr_expire_process(unsigned long arg);
 
 #ifdef CONFIG_IP_MROUTE_MULTIPLE_TABLES
@@ -157,9 +157,12 @@ static struct mr_table *ipmr_get_table(struct net *net, u32 id)
 static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4,
                           struct mr_table **mrt)
 {
-       struct ipmr_result res;
-       struct fib_lookup_arg arg = { .result = &res, };
        int err;
+       struct ipmr_result res;
+       struct fib_lookup_arg arg = {
+               .result = &res,
+               .flags = FIB_LOOKUP_NOREF,
+       };
 
        err = fib_rules_lookup(net->ipv4.mr_rules_ops,
                               flowi4_to_flowi(flp4), 0, &arg);
@@ -345,7 +348,7 @@ static struct mr_table *ipmr_new_table(struct net *net, u32 id)
 static void ipmr_free_table(struct mr_table *mrt)
 {
        del_timer_sync(&mrt->ipmr_expire_timer);
-       mroute_clean_tables(mrt);
+       mroute_clean_tables(mrt, true);
        kfree(mrt);
 }
 
@@ -1196,7 +1199,7 @@ static int ipmr_mfc_add(struct net *net, struct mr_table *mrt,
  *     Close the multicast socket, and clear the vif tables etc
  */
 
-static void mroute_clean_tables(struct mr_table *mrt)
+static void mroute_clean_tables(struct mr_table *mrt, bool all)
 {
        int i;
        LIST_HEAD(list);
@@ -1205,8 +1208,9 @@ static void mroute_clean_tables(struct mr_table *mrt)
        /* Shut down all active vif entries */
 
        for (i = 0; i < mrt->maxvif; i++) {
-               if (!(mrt->vif_table[i].flags & VIFF_STATIC))
-                       vif_delete(mrt, i, 0, &list);
+               if (!all && (mrt->vif_table[i].flags & VIFF_STATIC))
+                       continue;
+               vif_delete(mrt, i, 0, &list);
        }
        unregister_netdevice_many(&list);
 
@@ -1214,7 +1218,7 @@ static void mroute_clean_tables(struct mr_table *mrt)
 
        for (i = 0; i < MFC_LINES; i++) {
                list_for_each_entry_safe(c, next, &mrt->mfc_cache_array[i], list) {
-                       if (c->mfc_flags & MFC_STATIC)
+                       if (!all && (c->mfc_flags & MFC_STATIC))
                                continue;
                        list_del_rcu(&c->list);
                        mroute_netlink_event(mrt, c, RTM_DELROUTE);
@@ -1249,7 +1253,7 @@ static void mrtsock_destruct(struct sock *sk)
                                                    NETCONFA_IFINDEX_ALL,
                                                    net->ipv4.devconf_all);
                        RCU_INIT_POINTER(mrt->mroute_sk, NULL);
-                       mroute_clean_tables(mrt);
+                       mroute_clean_tables(mrt, false);
                }
        }
        rtnl_unlock();
@@ -1658,7 +1662,7 @@ static void ip_encap(struct sk_buff *skb, __be32 saddr, __be32 daddr)
        iph->protocol   =       IPPROTO_IPIP;
        iph->ihl        =       5;
        iph->tot_len    =       htons(skb->len);
-       ip_select_ident(iph, skb_dst(skb), NULL);
+       ip_select_ident(skb, NULL);
        ip_send_check(iph);
 
        memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt));
@@ -1669,8 +1673,8 @@ static inline int ipmr_forward_finish(struct sk_buff *skb)
 {
        struct ip_options *opt = &(IPCB(skb)->opt);
 
-       IP_INC_STATS_BH(dev_net(skb_dst(skb)->dev), IPSTATS_MIB_OUTFORWDATAGRAMS);
-       IP_ADD_STATS_BH(dev_net(skb_dst(skb)->dev), IPSTATS_MIB_OUTOCTETS, skb->len);
+       IP_INC_STATS(dev_net(skb_dst(skb)->dev), IPSTATS_MIB_OUTFORWDATAGRAMS);
+       IP_ADD_STATS(dev_net(skb_dst(skb)->dev), IPSTATS_MIB_OUTOCTETS, skb->len);
 
        if (unlikely(opt->optlen))
                ip_forward_options(skb);
@@ -1732,7 +1736,7 @@ static void ipmr_queue_xmit(struct net *net, struct mr_table *mrt,
                 * to blackhole.
                 */
 
-               IP_INC_STATS_BH(dev_net(dev), IPSTATS_MIB_FRAGFAILS);
+               IP_INC_STATS(dev_net(dev), IPSTATS_MIB_FRAGFAILS);
                ip_rt_put(rt);
                goto out_free;
        }
@@ -2252,13 +2256,14 @@ int ipmr_get_route(struct net *net, struct sk_buff *skb,
 }
 
 static int ipmr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb,
-                           u32 portid, u32 seq, struct mfc_cache *c, int cmd)
+                           u32 portid, u32 seq, struct mfc_cache *c, int cmd,
+                           int flags)
 {
        struct nlmsghdr *nlh;
        struct rtmsg *rtm;
        int err;
 
-       nlh = nlmsg_put(skb, portid, seq, cmd, sizeof(*rtm), NLM_F_MULTI);
+       nlh = nlmsg_put(skb, portid, seq, cmd, sizeof(*rtm), flags);
        if (nlh == NULL)
                return -EMSGSIZE;
 
@@ -2326,7 +2331,7 @@ static void mroute_netlink_event(struct mr_table *mrt, struct mfc_cache *mfc,
        if (skb == NULL)
                goto errout;
 
-       err = ipmr_fill_mroute(mrt, skb, 0, 0, mfc, cmd);
+       err = ipmr_fill_mroute(mrt, skb, 0, 0, mfc, cmd, 0);
        if (err < 0)
                goto errout;
 
@@ -2365,7 +2370,8 @@ static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb)
                                if (ipmr_fill_mroute(mrt, skb,
                                                     NETLINK_CB(cb->skb).portid,
                                                     cb->nlh->nlmsg_seq,
-                                                    mfc, RTM_NEWROUTE) < 0)
+                                                    mfc, RTM_NEWROUTE,
+                                                    NLM_F_MULTI) < 0)
                                        goto done;
 next_entry:
                                e++;
@@ -2379,7 +2385,8 @@ next_entry:
                        if (ipmr_fill_mroute(mrt, skb,
                                             NETLINK_CB(cb->skb).portid,
                                             cb->nlh->nlmsg_seq,
-                                            mfc, RTM_NEWROUTE) < 0) {
+                                            mfc, RTM_NEWROUTE,
+                                            NLM_F_MULTI) < 0) {
                                spin_unlock_bh(&mfc_unres_lock);
                                goto done;
                        }