staging: hv: Fix race condition on vmbus channel initialization
authorHaiyang Zhang <haiyangz@microsoft.com>
Fri, 28 May 2010 23:22:44 +0000 (23:22 +0000)
committerGreg Kroah-Hartman <gregkh@suse.de>
Wed, 30 Jun 2010 15:18:14 +0000 (08:18 -0700)
There is a possible race condition when hv_utils starts to load immediately
after hv_vmbus is loading - null pointer error could happen.
This patch added wait/completion to ensure all channels are ready before
vmbus loading completes. So another module won't have any uninitialized channel.

Signed-off-by: Haiyang Zhang <haiyangz@microsoft.com>
Signed-off-by: Hank Janssen <hjanssen@microsoft.com>
Signed-off-by: Greg Kroah-Hartman <gregkh@suse.de>
drivers/staging/hv/channel_mgmt.c
drivers/staging/hv/vmbus.h
drivers/staging/hv/vmbus_drv.c

index 3f53b4d1e4cffcf09552871b30f5c9d6f98ce8a3..12db555a3a5d59b1348c4f3616e246991f7c3246 100644 (file)
@@ -23,6 +23,7 @@
 #include <linux/slab.h>
 #include <linux/list.h>
 #include <linux/module.h>
+#include <linux/completion.h>
 #include "osd.h"
 #include "logging.h"
 #include "vmbus_private.h"
@@ -293,6 +294,25 @@ void FreeVmbusChannel(struct vmbus_channel *Channel)
                              Channel);
 }
 
+
+DECLARE_COMPLETION(hv_channel_ready);
+
+/*
+ * Count initialized channels, and ensure all channels are ready when hv_vmbus
+ * module loading completes.
+ */
+static void count_hv_channel(void)
+{
+       static int counter;
+       unsigned long flags;
+
+       spin_lock_irqsave(&gVmbusConnection.channel_lock, flags);
+       if (++counter == MAX_MSG_TYPES)
+               complete(&hv_channel_ready);
+       spin_unlock_irqrestore(&gVmbusConnection.channel_lock, flags);
+}
+
+
 /*
  * VmbusChannelProcessOffer - Process the offer by creating a channel/device
  * associated with this offer
@@ -373,22 +393,21 @@ static void VmbusChannelProcessOffer(void *context)
                 * can cleanup properly
                 */
                newChannel->State = CHANNEL_OPEN_STATE;
-               cnt = 0;
 
-               while (cnt != MAX_MSG_TYPES) {
+               /* Open IC channels */
+               for (cnt = 0; cnt < MAX_MSG_TYPES; cnt++) {
                        if (memcmp(&newChannel->OfferMsg.Offer.InterfaceType,
                                   &hv_cb_utils[cnt].data,
-                                  sizeof(struct hv_guid)) == 0) {
+                                  sizeof(struct hv_guid)) == 0 &&
+                               VmbusChannelOpen(newChannel, 2 * PAGE_SIZE,
+                                                2 * PAGE_SIZE, NULL, 0,
+                                                hv_cb_utils[cnt].callback,
+                                                newChannel) == 0) {
+                               hv_cb_utils[cnt].channel = newChannel;
                                DPRINT_INFO(VMBUS, "%s",
-                                           hv_cb_utils[cnt].log_msg);
-
-                               if (VmbusChannelOpen(newChannel, 2 * PAGE_SIZE,
-                                                   2 * PAGE_SIZE, NULL, 0,
-                                                   hv_cb_utils[cnt].callback,
-                                                   newChannel) == 0)
-                                       hv_cb_utils[cnt].channel = newChannel;
+                                               hv_cb_utils[cnt].log_msg);
+                               count_hv_channel();
                        }
-                       cnt++;
                }
        }
        DPRINT_EXIT(VMBUS);
index 0c6ee0f487f3538384e0d4f2931b2f496b6a01e8..3c14b2926e0018e1e8ef157df37fc770a6d440b3 100644 (file)
@@ -74,4 +74,6 @@ int vmbus_child_driver_register(struct driver_context *driver_ctx);
 void vmbus_child_driver_unregister(struct driver_context *driver_ctx);
 void vmbus_get_interface(struct vmbus_channel_interface *interface);
 
+extern struct completion hv_channel_ready;
+
 #endif /* _VMBUS_H_ */
index c21731a12ca7a7a8e8a9929369b2f99833c5e2f2..22c80ece6388c5e72661864105f4d35ebf7e1824 100644 (file)
@@ -27,6 +27,7 @@
 #include <linux/pci.h>
 #include <linux/dmi.h>
 #include <linux/slab.h>
+#include <linux/completion.h>
 #include "version_info.h"
 #include "osd.h"
 #include "logging.h"
@@ -356,6 +357,8 @@ static int vmbus_bus_init(int (*drv_init)(struct hv_driver *drv))
 
        vmbus_drv_obj->GetChannelOffers();
 
+       wait_for_completion(&hv_channel_ready);
+
 cleanup:
        DPRINT_EXIT(VMBUS_DRV);