AMD IOMMU: check for invalid device pointers
authorJoerg Roedel <joerg.roedel@amd.com>
Thu, 4 Sep 2008 13:04:26 +0000 (15:04 +0200)
committerIngo Molnar <mingo@elte.hu>
Fri, 19 Sep 2008 10:59:03 +0000 (12:59 +0200)
Currently AMD IOMMU code triggers a BUG_ON if NULL is passed as the
device. This is inconsistent with other IOMMU implementations.

Signed-off-by: Joerg Roedel <joerg.roedel@amd.com>
Signed-off-by: Ingo Molnar <mingo@elte.hu>
arch/x86/kernel/amd_iommu.c

index 01c68c38840d8b8997e5d801e508dd20185605f0..695e0fc41b108389242ceb41fbf872af004f75df 100644 (file)
@@ -645,6 +645,18 @@ static void set_device_domain(struct amd_iommu *iommu,
  *
  *****************************************************************************/
 
+/*
+ * This function checks if the driver got a valid device from the caller to
+ * avoid dereferencing invalid pointers.
+ */
+static bool check_device(struct device *dev)
+{
+       if (!dev || !dev->dma_mask)
+               return false;
+
+       return true;
+}
+
 /*
  * In the dma_ops path we only have the struct device. This function
  * finds the corresponding IOMMU, the protection domain and the
@@ -661,18 +673,19 @@ static int get_device_resources(struct device *dev,
        struct pci_dev *pcidev;
        u16 _bdf;
 
-       BUG_ON(!dev || dev->bus != &pci_bus_type || !dev->dma_mask);
+       *iommu = NULL;
+       *domain = NULL;
+       *bdf = 0xffff;
+
+       if (dev->bus != &pci_bus_type)
+               return 0;
 
        pcidev = to_pci_dev(dev);
        _bdf = calc_devid(pcidev->bus->number, pcidev->devfn);
 
        /* device not translated by any IOMMU in the system? */
-       if (_bdf > amd_iommu_last_bdf) {
-               *iommu = NULL;
-               *domain = NULL;
-               *bdf = 0xffff;
+       if (_bdf > amd_iommu_last_bdf)
                return 0;
-       }
 
        *bdf = amd_iommu_alias_table[_bdf];
 
@@ -826,6 +839,9 @@ static dma_addr_t map_single(struct device *dev, phys_addr_t paddr,
        u16 devid;
        dma_addr_t addr;
 
+       if (!check_device(dev))
+               return bad_dma_address;
+
        get_device_resources(dev, &iommu, &domain, &devid);
 
        if (iommu == NULL || domain == NULL)
@@ -860,7 +876,8 @@ static void unmap_single(struct device *dev, dma_addr_t dma_addr,
        struct protection_domain *domain;
        u16 devid;
 
-       if (!get_device_resources(dev, &iommu, &domain, &devid))
+       if (!check_device(dev) ||
+           !get_device_resources(dev, &iommu, &domain, &devid))
                /* device not handled by any AMD IOMMU */
                return;
 
@@ -910,6 +927,9 @@ static int map_sg(struct device *dev, struct scatterlist *sglist,
        phys_addr_t paddr;
        int mapped_elems = 0;
 
+       if (!check_device(dev))
+               return 0;
+
        get_device_resources(dev, &iommu, &domain, &devid);
 
        if (!iommu || !domain)
@@ -967,7 +987,8 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist,
        u16 devid;
        int i;
 
-       if (!get_device_resources(dev, &iommu, &domain, &devid))
+       if (!check_device(dev) ||
+           !get_device_resources(dev, &iommu, &domain, &devid))
                return;
 
        spin_lock_irqsave(&domain->lock, flags);
@@ -999,6 +1020,9 @@ static void *alloc_coherent(struct device *dev, size_t size,
        u16 devid;
        phys_addr_t paddr;
 
+       if (!check_device(dev))
+               return NULL;
+
        virt_addr = (void *)__get_free_pages(flag, get_order(size));
        if (!virt_addr)
                return 0;
@@ -1047,6 +1071,9 @@ static void free_coherent(struct device *dev, size_t size,
        struct protection_domain *domain;
        u16 devid;
 
+       if (!check_device(dev))
+               return;
+
        get_device_resources(dev, &iommu, &domain, &devid);
 
        if (!iommu || !domain)