IB/core: Implement support for MMU notifiers regarding on demand paging regions
authorHaggai Eran <haggaie@mellanox.com>
Thu, 11 Dec 2014 15:04:18 +0000 (17:04 +0200)
committerRoland Dreier <roland@purestorage.com>
Tue, 16 Dec 2014 02:13:36 +0000 (18:13 -0800)
* Add an interval tree implementation for ODP umems. Create an
  interval tree for each ucontext (including a count of the number of
  ODP MRs in this context, semaphore, etc.), and register ODP umems in
  the interval tree.
* Add MMU notifiers handling functions, using the interval tree to
  notify only the relevant umems and underlying MRs.
* Register to receive MMU notifier events from the MM subsystem upon
  ODP MR registration (and unregister accordingly).
* Add a completion object to synchronize the destruction of ODP umems.
* Add mechanism to abort page faults when there's a concurrent invalidation.

The way we synchronize between concurrent invalidations and page
faults is by keeping a counter of currently running invalidations, and
a sequence number that is incremented whenever an invalidation is
caught. The page fault code checks the counter and also verifies that
the sequence number hasn't progressed before it updates the umem's
page tables. This is similar to what the kvm module does.

In order to prevent the case where we register a umem in the middle of
an ongoing notifier, we also keep a per ucontext counter of the total
number of active mmu notifiers. We only enable new umems when all the
running notifiers complete.

Signed-off-by: Sagi Grimberg <sagig@mellanox.com>
Signed-off-by: Shachar Raindel <raindel@mellanox.com>
Signed-off-by: Haggai Eran <haggaie@mellanox.com>
Signed-off-by: Yuval Dagan <yuvalda@mellanox.com>
Signed-off-by: Roland Dreier <roland@purestorage.com>
drivers/infiniband/Kconfig
drivers/infiniband/core/Makefile
drivers/infiniband/core/umem.c
drivers/infiniband/core/umem_odp.c
drivers/infiniband/core/umem_rbtree.c [new file with mode: 0644]
drivers/infiniband/core/uverbs_cmd.c
include/rdma/ib_umem_odp.h
include/rdma/ib_verbs.h

index 089a2c2af329e478c3fdfcbb6ff08ccbdfae81d0..b899531498eb0dc7924e4587f67dbc3f7c116313 100644 (file)
@@ -41,6 +41,7 @@ config INFINIBAND_USER_MEM
 config INFINIBAND_ON_DEMAND_PAGING
        bool "InfiniBand on-demand paging support"
        depends on INFINIBAND_USER_MEM
+       select MMU_NOTIFIER
        default y
        ---help---
          On demand paging support for the InfiniBand subsystem.
index c58f7913c5603ee7b9b41ae2e7be03b97ebf900e..acf73676444593704267ac9176696f95caa52335 100644 (file)
@@ -11,7 +11,7 @@ obj-$(CONFIG_INFINIBAND_USER_ACCESS) +=       ib_uverbs.o ib_ucm.o \
 ib_core-y :=                   packer.o ud_header.o verbs.o sysfs.o \
                                device.o fmr_pool.o cache.o netlink.o
 ib_core-$(CONFIG_INFINIBAND_USER_MEM) += umem.o
-ib_core-$(CONFIG_INFINIBAND_ON_DEMAND_PAGING) += umem_odp.o
+ib_core-$(CONFIG_INFINIBAND_ON_DEMAND_PAGING) += umem_odp.o umem_rbtree.o
 
 ib_mad-y :=                    mad.o smi.o agent.o mad_rmpp.o
 
index 5baceb79f21b3c526aefa0a510b7e730061bfe0a..aec7a6aa2951db47bc6b5be969a29d1867688b23 100644 (file)
@@ -72,7 +72,7 @@ static void __ib_umem_release(struct ib_device *dev, struct ib_umem *umem, int d
  * ib_umem_get - Pin and DMA map userspace memory.
  *
  * If access flags indicate ODP memory, avoid pinning. Instead, stores
- * the mm for future page fault handling.
+ * the mm for future page fault handling in conjunction with MMU notifiers.
  *
  * @context: userspace context to pin memory for
  * @addr: userspace virtual address to start at
index f889e8d793bd215727eb49dbdf9d43e1088646dd..6095872549e79fb0dd9c0b3702c4eb01e974fb3f 100644 (file)
 #include <rdma/ib_umem.h>
 #include <rdma/ib_umem_odp.h>
 
+static void ib_umem_notifier_start_account(struct ib_umem *item)
+{
+       mutex_lock(&item->odp_data->umem_mutex);
+
+       /* Only update private counters for this umem if it has them.
+        * Otherwise skip it. All page faults will be delayed for this umem. */
+       if (item->odp_data->mn_counters_active) {
+               int notifiers_count = item->odp_data->notifiers_count++;
+
+               if (notifiers_count == 0)
+                       /* Initialize the completion object for waiting on
+                        * notifiers. Since notifier_count is zero, no one
+                        * should be waiting right now. */
+                       reinit_completion(&item->odp_data->notifier_completion);
+       }
+       mutex_unlock(&item->odp_data->umem_mutex);
+}
+
+static void ib_umem_notifier_end_account(struct ib_umem *item)
+{
+       mutex_lock(&item->odp_data->umem_mutex);
+
+       /* Only update private counters for this umem if it has them.
+        * Otherwise skip it. All page faults will be delayed for this umem. */
+       if (item->odp_data->mn_counters_active) {
+               /*
+                * This sequence increase will notify the QP page fault that
+                * the page that is going to be mapped in the spte could have
+                * been freed.
+                */
+               ++item->odp_data->notifiers_seq;
+               if (--item->odp_data->notifiers_count == 0)
+                       complete_all(&item->odp_data->notifier_completion);
+       }
+       mutex_unlock(&item->odp_data->umem_mutex);
+}
+
+/* Account for a new mmu notifier in an ib_ucontext. */
+static void ib_ucontext_notifier_start_account(struct ib_ucontext *context)
+{
+       atomic_inc(&context->notifier_count);
+}
+
+/* Account for a terminating mmu notifier in an ib_ucontext.
+ *
+ * Must be called with the ib_ucontext->umem_rwsem semaphore unlocked, since
+ * the function takes the semaphore itself. */
+static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
+{
+       int zero_notifiers = atomic_dec_and_test(&context->notifier_count);
+
+       if (zero_notifiers &&
+           !list_empty(&context->no_private_counters)) {
+               /* No currently running mmu notifiers. Now is the chance to
+                * add private accounting to all previously added umems. */
+               struct ib_umem_odp *odp_data, *next;
+
+               /* Prevent concurrent mmu notifiers from working on the
+                * no_private_counters list. */
+               down_write(&context->umem_rwsem);
+
+               /* Read the notifier_count again, with the umem_rwsem
+                * semaphore taken for write. */
+               if (!atomic_read(&context->notifier_count)) {
+                       list_for_each_entry_safe(odp_data, next,
+                                                &context->no_private_counters,
+                                                no_private_counters) {
+                               mutex_lock(&odp_data->umem_mutex);
+                               odp_data->mn_counters_active = true;
+                               list_del(&odp_data->no_private_counters);
+                               complete_all(&odp_data->notifier_completion);
+                               mutex_unlock(&odp_data->umem_mutex);
+                       }
+               }
+
+               up_write(&context->umem_rwsem);
+       }
+}
+
+static int ib_umem_notifier_release_trampoline(struct ib_umem *item, u64 start,
+                                              u64 end, void *cookie) {
+       /*
+        * Increase the number of notifiers running, to
+        * prevent any further fault handling on this MR.
+        */
+       ib_umem_notifier_start_account(item);
+       item->odp_data->dying = 1;
+       /* Make sure that the fact the umem is dying is out before we release
+        * all pending page faults. */
+       smp_wmb();
+       complete_all(&item->odp_data->notifier_completion);
+       item->context->invalidate_range(item, ib_umem_start(item),
+                                       ib_umem_end(item));
+       return 0;
+}
+
+static void ib_umem_notifier_release(struct mmu_notifier *mn,
+                                    struct mm_struct *mm)
+{
+       struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
+
+       if (!context->invalidate_range)
+               return;
+
+       ib_ucontext_notifier_start_account(context);
+       down_read(&context->umem_rwsem);
+       rbt_ib_umem_for_each_in_range(&context->umem_tree, 0,
+                                     ULLONG_MAX,
+                                     ib_umem_notifier_release_trampoline,
+                                     NULL);
+       up_read(&context->umem_rwsem);
+}
+
+static int invalidate_page_trampoline(struct ib_umem *item, u64 start,
+                                     u64 end, void *cookie)
+{
+       ib_umem_notifier_start_account(item);
+       item->context->invalidate_range(item, start, start + PAGE_SIZE);
+       ib_umem_notifier_end_account(item);
+       return 0;
+}
+
+static void ib_umem_notifier_invalidate_page(struct mmu_notifier *mn,
+                                            struct mm_struct *mm,
+                                            unsigned long address)
+{
+       struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
+
+       if (!context->invalidate_range)
+               return;
+
+       ib_ucontext_notifier_start_account(context);
+       down_read(&context->umem_rwsem);
+       rbt_ib_umem_for_each_in_range(&context->umem_tree, address,
+                                     address + PAGE_SIZE,
+                                     invalidate_page_trampoline, NULL);
+       up_read(&context->umem_rwsem);
+       ib_ucontext_notifier_end_account(context);
+}
+
+static int invalidate_range_start_trampoline(struct ib_umem *item, u64 start,
+                                            u64 end, void *cookie)
+{
+       ib_umem_notifier_start_account(item);
+       item->context->invalidate_range(item, start, end);
+       return 0;
+}
+
+static void ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
+                                                   struct mm_struct *mm,
+                                                   unsigned long start,
+                                                   unsigned long end)
+{
+       struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
+
+       if (!context->invalidate_range)
+               return;
+
+       ib_ucontext_notifier_start_account(context);
+       down_read(&context->umem_rwsem);
+       rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
+                                     end,
+                                     invalidate_range_start_trampoline, NULL);
+       up_read(&context->umem_rwsem);
+}
+
+static int invalidate_range_end_trampoline(struct ib_umem *item, u64 start,
+                                          u64 end, void *cookie)
+{
+       ib_umem_notifier_end_account(item);
+       return 0;
+}
+
+static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
+                                                 struct mm_struct *mm,
+                                                 unsigned long start,
+                                                 unsigned long end)
+{
+       struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
+
+       if (!context->invalidate_range)
+               return;
+
+       down_read(&context->umem_rwsem);
+       rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
+                                     end,
+                                     invalidate_range_end_trampoline, NULL);
+       up_read(&context->umem_rwsem);
+       ib_ucontext_notifier_end_account(context);
+}
+
+static struct mmu_notifier_ops ib_umem_notifiers = {
+       .release                    = ib_umem_notifier_release,
+       .invalidate_page            = ib_umem_notifier_invalidate_page,
+       .invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
+       .invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
+};
+
 int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem)
 {
        int ret_val;
        struct pid *our_pid;
+       struct mm_struct *mm = get_task_mm(current);
+
+       if (!mm)
+               return -EINVAL;
 
        /* Prevent creating ODP MRs in child processes */
        rcu_read_lock();
        our_pid = get_task_pid(current->group_leader, PIDTYPE_PID);
        rcu_read_unlock();
        put_pid(our_pid);
-       if (context->tgid != our_pid)
-               return -EINVAL;
+       if (context->tgid != our_pid) {
+               ret_val = -EINVAL;
+               goto out_mm;
+       }
 
        umem->hugetlb = 0;
        umem->odp_data = kzalloc(sizeof(*umem->odp_data), GFP_KERNEL);
-       if (!umem->odp_data)
-               return -ENOMEM;
+       if (!umem->odp_data) {
+               ret_val = -ENOMEM;
+               goto out_mm;
+       }
+       umem->odp_data->umem = umem;
 
        mutex_init(&umem->odp_data->umem_mutex);
 
+       init_completion(&umem->odp_data->notifier_completion);
+
        umem->odp_data->page_list = vzalloc(ib_umem_num_pages(umem) *
                                            sizeof(*umem->odp_data->page_list));
        if (!umem->odp_data->page_list) {
@@ -75,17 +284,72 @@ int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem)
                goto out_page_list;
        }
 
+       /*
+        * When using MMU notifiers, we will get a
+        * notification before the "current" task (and MM) is
+        * destroyed. We use the umem_rwsem semaphore to synchronize.
+        */
+       down_write(&context->umem_rwsem);
+       context->odp_mrs_count++;
+       if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
+               rbt_ib_umem_insert(&umem->odp_data->interval_tree,
+                                  &context->umem_tree);
+       if (likely(!atomic_read(&context->notifier_count)))
+               umem->odp_data->mn_counters_active = true;
+       else
+               list_add(&umem->odp_data->no_private_counters,
+                        &context->no_private_counters);
+       downgrade_write(&context->umem_rwsem);
+
+       if (context->odp_mrs_count == 1) {
+               /*
+                * Note that at this point, no MMU notifier is running
+                * for this context!
+                */
+               atomic_set(&context->notifier_count, 0);
+               INIT_HLIST_NODE(&context->mn.hlist);
+               context->mn.ops = &ib_umem_notifiers;
+               /*
+                * Lock-dep detects a false positive for mmap_sem vs.
+                * umem_rwsem, due to not grasping downgrade_write correctly.
+                */
+               lockdep_off();
+               ret_val = mmu_notifier_register(&context->mn, mm);
+               lockdep_on();
+               if (ret_val) {
+                       pr_err("Failed to register mmu_notifier %d\n", ret_val);
+                       ret_val = -EBUSY;
+                       goto out_mutex;
+               }
+       }
+
+       up_read(&context->umem_rwsem);
+
+       /*
+        * Note that doing an mmput can cause a notifier for the relevant mm.
+        * If the notifier is called while we hold the umem_rwsem, this will
+        * cause a deadlock. Therefore, we release the reference only after we
+        * released the semaphore.
+        */
+       mmput(mm);
        return 0;
 
+out_mutex:
+       up_read(&context->umem_rwsem);
+       vfree(umem->odp_data->dma_list);
 out_page_list:
        vfree(umem->odp_data->page_list);
 out_odp_data:
        kfree(umem->odp_data);
+out_mm:
+       mmput(mm);
        return ret_val;
 }
 
 void ib_umem_odp_release(struct ib_umem *umem)
 {
+       struct ib_ucontext *context = umem->context;
+
        /*
         * Ensure that no more pages are mapped in the umem.
         *
@@ -95,6 +359,54 @@ void ib_umem_odp_release(struct ib_umem *umem)
        ib_umem_odp_unmap_dma_pages(umem, ib_umem_start(umem),
                                    ib_umem_end(umem));
 
+       down_write(&context->umem_rwsem);
+       if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
+               rbt_ib_umem_remove(&umem->odp_data->interval_tree,
+                                  &context->umem_tree);
+       context->odp_mrs_count--;
+       if (!umem->odp_data->mn_counters_active) {
+               list_del(&umem->odp_data->no_private_counters);
+               complete_all(&umem->odp_data->notifier_completion);
+       }
+
+       /*
+        * Downgrade the lock to a read lock. This ensures that the notifiers
+        * (who lock the mutex for reading) will be able to finish, and we
+        * will be able to enventually obtain the mmu notifiers SRCU. Note
+        * that since we are doing it atomically, no other user could register
+        * and unregister while we do the check.
+        */
+       downgrade_write(&context->umem_rwsem);
+       if (!context->odp_mrs_count) {
+               struct task_struct *owning_process = NULL;
+               struct mm_struct *owning_mm        = NULL;
+
+               owning_process = get_pid_task(context->tgid,
+                                             PIDTYPE_PID);
+               if (owning_process == NULL)
+                       /*
+                        * The process is already dead, notifier were removed
+                        * already.
+                        */
+                       goto out;
+
+               owning_mm = get_task_mm(owning_process);
+               if (owning_mm == NULL)
+                       /*
+                        * The process' mm is already dead, notifier were
+                        * removed already.
+                        */
+                       goto out_put_task;
+               mmu_notifier_unregister(&context->mn, owning_mm);
+
+               mmput(owning_mm);
+
+out_put_task:
+               put_task_struct(owning_process);
+       }
+out:
+       up_read(&context->umem_rwsem);
+
        vfree(umem->odp_data->dma_list);
        vfree(umem->odp_data->page_list);
        kfree(umem->odp_data);
@@ -112,7 +424,8 @@ void ib_umem_odp_release(struct ib_umem *umem)
  *               the sequence number is taken from
  *               umem->odp_data->notifiers_seq.
  *
- * The function returns -EFAULT if the DMA mapping operation fails.
+ * The function returns -EFAULT if the DMA mapping operation fails. It returns
+ * -EAGAIN if a concurrent invalidation prevents us from updating the page.
  *
  * The page is released via put_page even if the operation failed. For
  * on-demand pinning, the page is released whenever it isn't stored in the
@@ -121,6 +434,7 @@ void ib_umem_odp_release(struct ib_umem *umem)
 static int ib_umem_odp_map_dma_single_page(
                struct ib_umem *umem,
                int page_index,
+               u64 base_virt_addr,
                struct page *page,
                u64 access_mask,
                unsigned long current_seq)
@@ -128,9 +442,19 @@ static int ib_umem_odp_map_dma_single_page(
        struct ib_device *dev = umem->context->device;
        dma_addr_t dma_addr;
        int stored_page = 0;
+       int remove_existing_mapping = 0;
        int ret = 0;
 
        mutex_lock(&umem->odp_data->umem_mutex);
+       /*
+        * Note: we avoid writing if seq is different from the initial seq, to
+        * handle case of a racing notifier. This check also allows us to bail
+        * early if we have a notifier running in parallel with us.
+        */
+       if (ib_umem_mmu_notifier_retry(umem, current_seq)) {
+               ret = -EAGAIN;
+               goto out;
+       }
        if (!(umem->odp_data->dma_list[page_index])) {
                dma_addr = ib_dma_map_page(dev,
                                           page,
@@ -148,14 +472,27 @@ static int ib_umem_odp_map_dma_single_page(
        } else {
                pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
                       umem->odp_data->page_list[page_index], page);
+               /* Better remove the mapping now, to prevent any further
+                * damage. */
+               remove_existing_mapping = 1;
        }
 
 out:
        mutex_unlock(&umem->odp_data->umem_mutex);
 
-       if (!stored_page)
+       /* On Demand Paging - avoid pinning the page */
+       if (umem->context->invalidate_range || !stored_page)
                put_page(page);
 
+       if (remove_existing_mapping && umem->context->invalidate_range) {
+               invalidate_page_trampoline(
+                       umem,
+                       base_virt_addr + (page_index * PAGE_SIZE),
+                       base_virt_addr + ((page_index+1)*PAGE_SIZE),
+                       NULL);
+               ret = -EAGAIN;
+       }
+
        return ret;
 }
 
@@ -168,6 +505,8 @@ out:
  *
  * Returns the number of pages mapped in success, negative error code
  * for failure.
+ * An -EAGAIN error code is returned when a concurrent mmu notifier prevents
+ * the function from completing its task.
  *
  * @umem: the umem to map and pin
  * @user_virt: the address from which we need to map.
@@ -189,6 +528,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem *umem, u64 user_virt, u64 bcnt,
        struct page       **local_page_list = NULL;
        u64 off;
        int j, k, ret = 0, start_idx, npages = 0;
+       u64 base_virt_addr;
 
        if (access_mask == 0)
                return -EINVAL;
@@ -203,6 +543,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem *umem, u64 user_virt, u64 bcnt,
 
        off = user_virt & (~PAGE_MASK);
        user_virt = user_virt & PAGE_MASK;
+       base_virt_addr = user_virt;
        bcnt += off; /* Charge for the first page offset as well. */
 
        owning_process = get_pid_task(umem->context->tgid, PIDTYPE_PID);
@@ -246,8 +587,8 @@ int ib_umem_odp_map_dma_pages(struct ib_umem *umem, u64 user_virt, u64 bcnt,
                user_virt += npages << PAGE_SHIFT;
                for (j = 0; j < npages; ++j) {
                        ret = ib_umem_odp_map_dma_single_page(
-                               umem, k, local_page_list[j], access_mask,
-                               current_seq);
+                               umem, k, base_virt_addr, local_page_list[j],
+                               access_mask, current_seq);
                        if (ret < 0)
                                break;
                        k++;
@@ -286,6 +627,11 @@ void ib_umem_odp_unmap_dma_pages(struct ib_umem *umem, u64 virt,
 
        virt  = max_t(u64, virt,  ib_umem_start(umem));
        bound = min_t(u64, bound, ib_umem_end(umem));
+       /* Note that during the run of this function, the
+        * notifiers_count of the MR is > 0, preventing any racing
+        * faults from completion. We might be racing with other
+        * invalidations, so we must make sure we free each page only
+        * once. */
        for (addr = virt; addr < bound; addr += (u64)umem->page_size) {
                idx = (addr - ib_umem_start(umem)) / PAGE_SIZE;
                mutex_lock(&umem->odp_data->umem_mutex);
@@ -300,8 +646,21 @@ void ib_umem_odp_unmap_dma_pages(struct ib_umem *umem, u64 virt,
                        ib_dma_unmap_page(dev, dma_addr, PAGE_SIZE,
                                          DMA_BIDIRECTIONAL);
                        if (dma & ODP_WRITE_ALLOWED_BIT)
-                               set_page_dirty_lock(head_page);
-                       put_page(page);
+                               /*
+                                * set_page_dirty prefers being called with
+                                * the page lock. However, MMU notifiers are
+                                * called sometimes with and sometimes without
+                                * the lock. We rely on the umem_mutex instead
+                                * to prevent other mmu notifiers from
+                                * continuing and allowing the page mapping to
+                                * be removed.
+                                */
+                               set_page_dirty(head_page);
+                       /* on demand pinning support */
+                       if (!umem->context->invalidate_range)
+                               put_page(page);
+                       umem->odp_data->page_list[idx] = NULL;
+                       umem->odp_data->dma_list[idx] = 0;
                }
                mutex_unlock(&umem->odp_data->umem_mutex);
        }
diff --git a/drivers/infiniband/core/umem_rbtree.c b/drivers/infiniband/core/umem_rbtree.c
new file mode 100644 (file)
index 0000000..727d788
--- /dev/null
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2014 Mellanox Technologies. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses.  You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the
+ * OpenIB.org BSD license below:
+ *
+ *     Redistribution and use in source and binary forms, with or
+ *     without modification, are permitted provided that the following
+ *     conditions are met:
+ *
+ *      - Redistributions of source code must retain the above
+ *        copyright notice, this list of conditions and the following
+ *        disclaimer.
+ *
+ *      - Redistributions in binary form must reproduce the above
+ *        copyright notice, this list of conditions and the following
+ *        disclaimer in the documentation and/or other materials
+ *        provided with the distribution.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+ * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include <linux/kernel.h>
+#include <linux/module.h>
+#include <linux/interval_tree_generic.h>
+#include <linux/sched.h>
+#include <linux/gfp.h>
+#include <rdma/ib_umem_odp.h>
+
+/*
+ * The ib_umem list keeps track of memory regions for which the HW
+ * device request to receive notification when the related memory
+ * mapping is changed.
+ *
+ * ib_umem_lock protects the list.
+ */
+
+static inline u64 node_start(struct umem_odp_node *n)
+{
+       struct ib_umem_odp *umem_odp =
+                       container_of(n, struct ib_umem_odp, interval_tree);
+
+       return ib_umem_start(umem_odp->umem);
+}
+
+/* Note that the representation of the intervals in the interval tree
+ * considers the ending point as contained in the interval, while the
+ * function ib_umem_end returns the first address which is not contained
+ * in the umem.
+ */
+static inline u64 node_last(struct umem_odp_node *n)
+{
+       struct ib_umem_odp *umem_odp =
+                       container_of(n, struct ib_umem_odp, interval_tree);
+
+       return ib_umem_end(umem_odp->umem) - 1;
+}
+
+INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
+                    node_start, node_last, , rbt_ib_umem)
+
+/* @last is not a part of the interval. See comment for function
+ * node_last.
+ */
+int rbt_ib_umem_for_each_in_range(struct rb_root *root,
+                                 u64 start, u64 last,
+                                 umem_call_back cb,
+                                 void *cookie)
+{
+       int ret_val = 0;
+       struct umem_odp_node *node;
+       struct ib_umem_odp *umem;
+
+       if (unlikely(start == last))
+               return ret_val;
+
+       for (node = rbt_ib_umem_iter_first(root, start, last - 1); node;
+                       node = rbt_ib_umem_iter_next(node, start, last - 1)) {
+               umem = container_of(node, struct ib_umem_odp, interval_tree);
+               ret_val = cb(umem->umem, start, last, cookie) || ret_val;
+       }
+
+       return ret_val;
+}
index 70b697d8fbb3e036afca13d220e09a922eee7b13..532d8eba8b0203ab65a2a8ed1f096389253b1dca 100644 (file)
@@ -289,6 +289,9 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file,
        struct ib_uverbs_get_context_resp resp;
        struct ib_udata                   udata;
        struct ib_device                 *ibdev = file->device->ib_dev;
+#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
+       struct ib_device_attr             dev_attr;
+#endif
        struct ib_ucontext               *ucontext;
        struct file                      *filp;
        int ret;
@@ -331,6 +334,20 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file,
        rcu_read_unlock();
        ucontext->closing = 0;
 
+#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
+       ucontext->umem_tree = RB_ROOT;
+       init_rwsem(&ucontext->umem_rwsem);
+       ucontext->odp_mrs_count = 0;
+       INIT_LIST_HEAD(&ucontext->no_private_counters);
+
+       ret = ib_query_device(ibdev, &dev_attr);
+       if (ret)
+               goto err_free;
+       if (!(dev_attr.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING))
+               ucontext->invalidate_range = NULL;
+
+#endif
+
        resp.num_comp_vectors = file->device->num_comp_vectors;
 
        ret = get_unused_fd_flags(O_CLOEXEC);
index b5a2df1923b7ae423d5cc3ba2c1c6a0b1a991406..3da0b167041b477e14a3e731e9de76e85148b46c 100644 (file)
 #define IB_UMEM_ODP_H
 
 #include <rdma/ib_umem.h>
+#include <rdma/ib_verbs.h>
+#include <linux/interval_tree.h>
+
+struct umem_odp_node {
+       u64 __subtree_last;
+       struct rb_node rb;
+};
 
 struct ib_umem_odp {
        /*
@@ -51,10 +58,27 @@ struct ib_umem_odp {
        dma_addr_t              *dma_list;
        /*
         * The umem_mutex protects the page_list and dma_list fields of an ODP
-        * umem, allowing only a single thread to map/unmap pages.
+        * umem, allowing only a single thread to map/unmap pages. The mutex
+        * also protects access to the mmu notifier counters.
         */
        struct mutex            umem_mutex;
        void                    *private; /* for the HW driver to use. */
+
+       /* When false, use the notifier counter in the ucontext struct. */
+       bool mn_counters_active;
+       int notifiers_seq;
+       int notifiers_count;
+
+       /* A linked list of umems that don't have private mmu notifier
+        * counters yet. */
+       struct list_head no_private_counters;
+       struct ib_umem          *umem;
+
+       /* Tree tracking */
+       struct umem_odp_node    interval_tree;
+
+       struct completion       notifier_completion;
+       int                     dying;
 };
 
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
@@ -82,6 +106,45 @@ int ib_umem_odp_map_dma_pages(struct ib_umem *umem, u64 start_offset, u64 bcnt,
 void ib_umem_odp_unmap_dma_pages(struct ib_umem *umem, u64 start_offset,
                                 u64 bound);
 
+void rbt_ib_umem_insert(struct umem_odp_node *node, struct rb_root *root);
+void rbt_ib_umem_remove(struct umem_odp_node *node, struct rb_root *root);
+typedef int (*umem_call_back)(struct ib_umem *item, u64 start, u64 end,
+                             void *cookie);
+/*
+ * Call the callback on each ib_umem in the range. Returns the logical or of
+ * the return values of the functions called.
+ */
+int rbt_ib_umem_for_each_in_range(struct rb_root *root, u64 start, u64 end,
+                                 umem_call_back cb, void *cookie);
+
+struct umem_odp_node *rbt_ib_umem_iter_first(struct rb_root *root,
+                                            u64 start, u64 last);
+struct umem_odp_node *rbt_ib_umem_iter_next(struct umem_odp_node *node,
+                                           u64 start, u64 last);
+
+static inline int ib_umem_mmu_notifier_retry(struct ib_umem *item,
+                                            unsigned long mmu_seq)
+{
+       /*
+        * This code is strongly based on the KVM code from
+        * mmu_notifier_retry. Should be called with
+        * the relevant locks taken (item->odp_data->umem_mutex
+        * and the ucontext umem_mutex semaphore locked for read).
+        */
+
+       /* Do not allow page faults while the new ib_umem hasn't seen a state
+        * with zero notifiers yet, and doesn't have its own valid set of
+        * private counters. */
+       if (!item->odp_data->mn_counters_active)
+               return 1;
+
+       if (unlikely(item->odp_data->notifiers_count))
+               return 1;
+       if (item->odp_data->notifiers_seq != mmu_seq)
+               return 1;
+       return 0;
+}
+
 #else /* CONFIG_INFINIBAND_ON_DEMAND_PAGING */
 
 static inline int ib_umem_odp_get(struct ib_ucontext *context,
index 3af5dcad1b691b40c148b019c178c10708744d0a..0d74f1de99aa89dee233ed408815459e2ad66d5f 100644 (file)
@@ -51,6 +51,7 @@
 #include <uapi/linux/if_ether.h>
 
 #include <linux/atomic.h>
+#include <linux/mmu_notifier.h>
 #include <asm/uaccess.h>
 
 extern struct workqueue_struct *ib_wq;
@@ -1139,6 +1140,8 @@ struct ib_fmr_attr {
        u8      page_shift;
 };
 
+struct ib_umem;
+
 struct ib_ucontext {
        struct ib_device       *device;
        struct list_head        pd_list;
@@ -1153,6 +1156,22 @@ struct ib_ucontext {
        int                     closing;
 
        struct pid             *tgid;
+#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
+       struct rb_root      umem_tree;
+       /*
+        * Protects .umem_rbroot and tree, as well as odp_mrs_count and
+        * mmu notifiers registration.
+        */
+       struct rw_semaphore     umem_rwsem;
+       void (*invalidate_range)(struct ib_umem *umem,
+                                unsigned long start, unsigned long end);
+
+       struct mmu_notifier     mn;
+       atomic_t                notifier_count;
+       /* A list of umems that don't have private mmu notifier counters yet. */
+       struct list_head        no_private_counters;
+       int                     odp_mrs_count;
+#endif
 };
 
 struct ib_uobject {