netvsc: use ERR_PTR to avoid dereference issues
authorstephen hemminger <stephen@networkplumber.org>
Wed, 19 Jul 2017 18:53:16 +0000 (11:53 -0700)
committerDavid S. Miller <davem@davemloft.net>
Thu, 20 Jul 2017 05:20:05 +0000 (22:20 -0700)
The rndis_filter_device_add function is called both in
probe context and RTNL context,and creates the netvsc_device
inner structure. It is easier to get the RTNL lock annotation
correct if it returns the object directly, rather than implicitly
by updating network device private data.

Signed-off-by: Stephen Hemminger <sthemmin@microsoft.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
drivers/net/hyperv/hyperv_net.h
drivers/net/hyperv/netvsc.c
drivers/net/hyperv/netvsc_drv.c
drivers/net/hyperv/rndis_filter.c

index 5d541a1462c207ba530a5ea99478672223762880..e620374727c8c580ffc3b58f1411fbd9bf50c0c6 100644 (file)
@@ -183,8 +183,8 @@ struct rndis_device {
 /* Interface */
 struct rndis_message;
 struct netvsc_device;
-int netvsc_device_add(struct hv_device *device,
-                     const struct netvsc_device_info *info);
+struct netvsc_device *netvsc_device_add(struct hv_device *device,
+                                       const struct netvsc_device_info *info);
 void netvsc_device_remove(struct hv_device *device);
 int netvsc_send(struct hv_device *device,
                struct hv_netvsc_packet *packet,
@@ -203,8 +203,8 @@ int netvsc_poll(struct napi_struct *napi, int budget);
 bool rndis_filter_opened(const struct netvsc_device *nvdev);
 int rndis_filter_open(struct netvsc_device *nvdev);
 int rndis_filter_close(struct netvsc_device *nvdev);
-int rndis_filter_device_add(struct hv_device *dev,
-                           struct netvsc_device_info *info);
+struct netvsc_device *rndis_filter_device_add(struct hv_device *dev,
+                                             struct netvsc_device_info *info);
 void rndis_filter_update(struct netvsc_device *nvdev);
 void rndis_filter_device_remove(struct hv_device *dev,
                                struct netvsc_device *nvdev);
index e202ec5d6f637a52bfc882a9897ab0ebf79605b3..4a2550559442fa52a584cd9e428939a09cb1cd41 100644 (file)
@@ -29,6 +29,8 @@
 #include <linux/netdevice.h>
 #include <linux/if_ether.h>
 #include <linux/vmalloc.h>
+#include <linux/rtnetlink.h>
+
 #include <asm/sync_bitops.h>
 
 #include "hyperv_net.h"
@@ -1272,8 +1274,8 @@ void netvsc_channel_cb(void *context)
  * netvsc_device_add - Callback when the device belonging to this
  * driver is added
  */
-int netvsc_device_add(struct hv_device *device,
-                     const struct netvsc_device_info *device_info)
+struct netvsc_device *netvsc_device_add(struct hv_device *device,
+                               const struct netvsc_device_info *device_info)
 {
        int i, ret = 0;
        int ring_size = device_info->ring_size;
@@ -1283,7 +1285,7 @@ int netvsc_device_add(struct hv_device *device,
 
        net_device = alloc_net_device();
        if (!net_device)
-               return -ENOMEM;
+               return ERR_PTR(-ENOMEM);
 
        net_device->ring_size = ring_size;
 
@@ -1339,7 +1341,7 @@ int netvsc_device_add(struct hv_device *device,
                goto close;
        }
 
-       return ret;
+       return net_device;
 
 close:
        netif_napi_del(&net_device->chan_table[0].napi);
@@ -1350,6 +1352,5 @@ close:
 cleanup:
        free_netvsc_device(&net_device->rcu);
 
-       return ret;
-
+       return ERR_PTR(ret);
 }
index 82e41c056e539a44a03c275cdf72e0e3191c72d0..0ca8c74143b49126c7fc7d6e4694ed4e67c6279e 100644 (file)
@@ -717,6 +717,7 @@ static int netvsc_set_queues(struct net_device *net, struct hv_device *dev,
                             u32 num_chn)
 {
        struct netvsc_device_info device_info;
+       struct netvsc_device *net_device;
        int ret;
 
        memset(&device_info, 0, sizeof(device_info));
@@ -732,7 +733,8 @@ static int netvsc_set_queues(struct net_device *net, struct hv_device *dev,
        if (ret)
                return ret;
 
-       return rndis_filter_device_add(dev, &device_info);
+       net_device = rndis_filter_device_add(dev, &device_info);
+       return IS_ERR(net_device) ? PTR_ERR(net_device) : 0;
 }
 
 static int netvsc_set_channels(struct net_device *net,
@@ -845,8 +847,10 @@ static int netvsc_change_mtu(struct net_device *ndev, int mtu)
        struct net_device_context *ndevctx = netdev_priv(ndev);
        struct netvsc_device *nvdev = rtnl_dereference(ndevctx->nvdev);
        struct hv_device *hdev = ndevctx->device_ctx;
+       int orig_mtu = ndev->mtu;
        struct netvsc_device_info device_info;
        bool was_opened;
+       int ret = 0;
 
        if (!nvdev || nvdev->destroy)
                return -ENODEV;
@@ -863,16 +867,16 @@ static int netvsc_change_mtu(struct net_device *ndev, int mtu)
 
        rndis_filter_device_remove(hdev, nvdev);
 
-       /* 'nvdev' has been freed in rndis_filter_device_remove() ->
-        * netvsc_device_remove () -> free_netvsc_device().
-        * We mustn't access it before it's re-created in
-        * rndis_filter_device_add() -> netvsc_device_add().
-        */
-
        ndev->mtu = mtu;
 
-       rndis_filter_device_add(hdev, &device_info);
-       nvdev = rtnl_dereference(ndevctx->nvdev);
+       nvdev = rndis_filter_device_add(hdev, &device_info);
+       if (IS_ERR(nvdev)) {
+               ret = PTR_ERR(nvdev);
+
+               /* Attempt rollback to original MTU */
+               ndev->mtu = orig_mtu;
+               rndis_filter_device_add(hdev, &device_info);
+       }
 
        if (was_opened)
                rndis_filter_open(nvdev);
@@ -882,7 +886,7 @@ static int netvsc_change_mtu(struct net_device *ndev, int mtu)
        /* We may have missed link change notifications */
        schedule_delayed_work(&ndevctx->dwork, 0);
 
-       return 0;
+       return ret;
 }
 
 static void netvsc_get_stats64(struct net_device *net,
@@ -1525,8 +1529,10 @@ static int netvsc_probe(struct hv_device *dev,
        memset(&device_info, 0, sizeof(device_info));
        device_info.ring_size = ring_size;
        device_info.num_chn = VRSS_CHANNEL_DEFAULT;
-       ret = rndis_filter_device_add(dev, &device_info);
-       if (ret != 0) {
+
+       nvdev = rndis_filter_device_add(dev, &device_info);
+       if (IS_ERR(nvdev)) {
+               ret = PTR_ERR(nvdev);
                netdev_err(net, "unable to add netvsc device (ret %d)\n", ret);
                free_netdev(net);
                hv_set_drvdata(dev, NULL);
@@ -1540,11 +1546,11 @@ static int netvsc_probe(struct hv_device *dev,
                NETIF_F_HW_VLAN_CTAG_TX | NETIF_F_HW_VLAN_CTAG_RX;
        net->vlan_features = net->features;
 
-       /* RCU not necessary here, device not registered */
-       nvdev = net_device_ctx->nvdev;
        netif_set_real_num_tx_queues(net, nvdev->num_chn);
        netif_set_real_num_rx_queues(net, nvdev->num_chn);
 
+       netdev_lockdep_set_classes(net);
+
        /* MTU range: 68 - 1500 or 65521 */
        net->min_mtu = NETVSC_MTU_MIN;
        if (nvdev->nvsp_version >= NVSP_PROTOCOL_VERSION_2)
index 313c6d00d7d9109da532bf45824b3d6174777b59..cacf1e5536f795e1254363cbefd2224fab2bc0f9 100644 (file)
@@ -658,9 +658,9 @@ cleanup:
 
 static int
 rndis_filter_set_offload_params(struct net_device *ndev,
+                               struct netvsc_device *nvdev,
                                struct ndis_offload_params *req_offloads)
 {
-       struct netvsc_device *nvdev = net_device_to_netvsc_device(ndev);
        struct rndis_device *rdev = nvdev->extension;
        struct rndis_request *request;
        struct rndis_set_request *set;
@@ -1052,8 +1052,8 @@ static void netvsc_sc_open(struct vmbus_channel *new_sc)
                complete(&nvscdev->channel_init_wait);
 }
 
-int rndis_filter_device_add(struct hv_device *dev,
-                           struct netvsc_device_info *device_info)
+struct netvsc_device *rndis_filter_device_add(struct hv_device *dev,
+                                     struct netvsc_device_info *device_info)
 {
        struct net_device *net = hv_get_drvdata(dev);
        struct net_device_context *net_device_ctx = netdev_priv(net);
@@ -1072,21 +1072,20 @@ int rndis_filter_device_add(struct hv_device *dev,
 
        rndis_device = get_rndis_device();
        if (!rndis_device)
-               return -ENODEV;
+               return ERR_PTR(-ENODEV);
 
        /*
         * Let the inner driver handle this first to create the netvsc channel
         * NOTE! Once the channel is created, we may get a receive callback
         * (RndisFilterOnReceive()) before this call is completed
         */
-       ret = netvsc_device_add(dev, device_info);
-       if (ret != 0) {
+       net_device = netvsc_device_add(dev, device_info);
+       if (IS_ERR(net_device)) {
                kfree(rndis_device);
-               return ret;
+               return net_device;
        }
 
        /* Initialize the rndis device */
-       net_device = net_device_ctx->nvdev;
        net_device->max_chn = 1;
        net_device->num_chn = 1;
 
@@ -1097,10 +1096,8 @@ int rndis_filter_device_add(struct hv_device *dev,
 
        /* Send the rndis initialization message */
        ret = rndis_filter_init_device(rndis_device);
-       if (ret != 0) {
-               rndis_filter_device_remove(dev, net_device);
-               return ret;
-       }
+       if (ret != 0)
+               goto err_dev_remv;
 
        /* Get the MTU from the host */
        size = sizeof(u32);
@@ -1112,19 +1109,15 @@ int rndis_filter_device_add(struct hv_device *dev,
 
        /* Get the mac address */
        ret = rndis_filter_query_device_mac(rndis_device);
-       if (ret != 0) {
-               rndis_filter_device_remove(dev, net_device);
-               return ret;
-       }
+       if (ret != 0)
+               goto err_dev_remv;
 
        memcpy(device_info->mac_adr, rndis_device->hw_mac_adr, ETH_ALEN);
 
        /* Find HW offload capabilities */
        ret = rndis_query_hwcaps(rndis_device, &hwcaps);
-       if (ret != 0) {
-               rndis_filter_device_remove(dev, net_device);
-               return ret;
-       }
+       if (ret != 0)
+               goto err_dev_remv;
 
        /* A value of zero means "no change"; now turn on what we want. */
        memset(&offloads, 0, sizeof(struct ndis_offload_params));
@@ -1179,7 +1172,7 @@ int rndis_filter_device_add(struct hv_device *dev,
 
        netif_set_gso_max_size(net, gso_max_size);
 
-       ret = rndis_filter_set_offload_params(net, &offloads);
+       ret = rndis_filter_set_offload_params(net, net_device, &offloads);
        if (ret)
                goto err_dev_remv;
 
@@ -1190,7 +1183,7 @@ int rndis_filter_device_add(struct hv_device *dev,
                   rndis_device->link_state ? "down" : "up");
 
        if (net_device->nvsp_version < NVSP_PROTOCOL_VERSION_5)
-               return 0;
+               return net_device;
 
        rndis_filter_query_link_speed(rndis_device);
 
@@ -1223,7 +1216,7 @@ int rndis_filter_device_add(struct hv_device *dev,
 
        num_rss_qs = net_device->num_chn - 1;
        if (num_rss_qs == 0)
-               return 0;
+               return net_device;
 
        refcount_set(&net_device->sc_offered, num_rss_qs);
        vmbus_set_sc_create_callback(dev->channel, netvsc_sc_open);
@@ -1260,11 +1253,11 @@ out:
                net_device->num_chn = 1;
        }
 
-       return 0; /* return 0 because primary channel can be used alone */
+       return net_device;
 
 err_dev_remv:
        rndis_filter_device_remove(dev, net_device);
-       return ret;
+       return ERR_PTR(ret);
 }
 
 void rndis_filter_device_remove(struct hv_device *dev,