mm/gup: Implement the dev_pagemap() logic in the generic get_user_pages_fast() function
index 9c047e951aa3d0399f331eb3914b09d207dc989b..e3d1e80424f495a944132aaf635e412b5a35db10 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -226,6 +226,7 @@ struct page *follow_page_mask(struct vm_area_struct *vma,
                              unsigned int *page_mask)
 {
        pgd_t *pgd;
+       p4d_t *p4d;
        pud_t *pud;
        pmd_t *pmd;
        spinlock_t *ptl;
@@ -243,8 +244,13 @@ struct page *follow_page_mask(struct vm_area_struct *vma,
        pgd = pgd_offset(mm, address);
        if (pgd_none(*pgd) || unlikely(pgd_bad(*pgd)))
                return no_page_table(vma, flags);
-
-       pud = pud_offset(pgd, address);
+       p4d = p4d_offset(pgd, address);
+       if (p4d_none(*p4d))
+               return no_page_table(vma, flags);
+       BUILD_BUG_ON(p4d_huge(*p4d));
+       if (unlikely(p4d_bad(*p4d)))
+               return no_page_table(vma, flags);
+       pud = pud_offset(p4d, address);
        if (pud_none(*pud))
                return no_page_table(vma, flags);
        if (pud_huge(*pud) && vma->vm_flags & VM_HUGETLB) {
@@ -325,6 +331,7 @@ static int get_gate_page(struct mm_struct *mm, unsigned long address,
                struct page **page)
 {
        pgd_t *pgd;
+       p4d_t *p4d;
        pud_t *pud;
        pmd_t *pmd;
        pte_t *pte;
@@ -338,7 +345,9 @@ static int get_gate_page(struct mm_struct *mm, unsigned long address,
        else
                pgd = pgd_offset_gate(mm, address);
        BUG_ON(pgd_none(*pgd));
-       pud = pud_offset(pgd, address);
+       p4d = p4d_offset(pgd, address);
+       BUG_ON(p4d_none(*p4d));
+       pud = pud_offset(p4d, address);
        BUG_ON(pud_none(*pud));
        pmd = pmd_offset(pud, address);
        if (pmd_none(*pmd))
@@ -1180,34 +1189,57 @@ struct page *get_dump_page(unsigned long addr)
  */
 #ifdef CONFIG_HAVE_GENERIC_RCU_GUP
 
+#ifndef gup_get_pte
+/*
+ * We assume that the PTE can be read atomically. If this is not the case for
+ * your architecture, please provide the helper.
+ */
+static inline pte_t gup_get_pte(pte_t *ptep)
+{
+       return READ_ONCE(*ptep);
+}
+#endif
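
An architecture whose PTEs are wider than its native atomic load width would supply its own gup_get_pte() before this #ifndef is seen. A minimal sketch of such an override, modeled on the old 32-bit x86 PAE code (pte_low/pte_high as in the PAE pte layout), re-reads the two halves until they form a consistent pair:

	static inline pte_t gup_get_pte(pte_t *ptep)
	{
		pte_t pte;

		/* Retry until the low word is stable across the pair. */
		do {
			pte.pte_low = ptep->pte_low;
			smp_rmb();
			pte.pte_high = ptep->pte_high;
			smp_rmb();
		} while (unlikely(pte.pte_low != ptep->pte_low));

		return pte;
	}
	#define gup_get_pte gup_get_pte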
+
+static void undo_dev_pagemap(int *nr, int nr_start, struct page **pages)
+{
+       while ((*nr) - nr_start) {
+               struct page *page = pages[--(*nr)];
+
+               ClearPageReferenced(page);
+               put_page(page);
+       }
+}
+
 #ifdef __HAVE_ARCH_PTE_SPECIAL
 static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
                         int write, struct page **pages, int *nr)
 {
+       struct dev_pagemap *pgmap = NULL;
+       int nr_start = *nr, ret = 0;
        pte_t *ptep, *ptem;
-       int ret = 0;
 
        ptem = ptep = pte_offset_map(&pmd, addr);
        do {
-               /*
-                * In the line below we are assuming that the pte can be read
-                * atomically. If this is not the case for your architecture,
-                * please wrap this in a helper function!
-                *
-                * for an example see gup_get_pte in arch/x86/mm/gup.c
-                */
-               pte_t pte = READ_ONCE(*ptep);
+               pte_t pte = gup_get_pte(ptep);
                struct page *head, *page;
 
                /*
                 * Similar to the PMD case below, NUMA hinting must take slow
                 * path using the pte_protnone check.
                 */
-               if (!pte_present(pte) || pte_special(pte) ||
-                       pte_protnone(pte) || (write && !pte_write(pte)))
+               if (pte_protnone(pte))
                        goto pte_unmap;
 
-               if (!arch_pte_access_permitted(pte, write))
+               if (!pte_access_permitted(pte, write))
+                       goto pte_unmap;
+
+               if (pte_devmap(pte)) {
+                       pgmap = get_dev_pagemap(pte_pfn(pte), pgmap);
+                       if (unlikely(!pgmap)) {
+                               undo_dev_pagemap(nr, nr_start, pages);
+                               goto pte_unmap;
+                       }
+               } else if (pte_special(pte))
                        goto pte_unmap;
 
                VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
@@ -1223,6 +1255,9 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
                }
 
                VM_BUG_ON_PAGE(compound_head(page) != head, page);
+
+               put_dev_pagemap(pgmap);
+               SetPageReferenced(page);
                pages[*nr] = page;
                (*nr)++;
 
@@ -1252,15 +1287,76 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
 }
 #endif /* __HAVE_ARCH_PTE_SPECIAL */
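
pte_access_permitted() folds the open-coded present/write test removed above and the pkeys-only arch_pte_access_permitted() hook into a single helper that architectures can override. If memory serves, the asm-generic fallback introduced alongside this series is essentially:

	#ifndef pte_access_permitted
	#define pte_access_permitted(pte, write) \
		(pte_present(pte) && (!(write) || pte_write(pte)))
	#endif

so on architectures with no extra protection state (x86 protection keys being the motivating exception) the present/write portion of the check compiles to exactly what the removed lines did.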
 
+#ifdef __HAVE_ARCH_PTE_DEVMAP
+static int __gup_device_huge(unsigned long pfn, unsigned long addr,
+               unsigned long end, struct page **pages, int *nr)
+{
+       int nr_start = *nr;
+       struct dev_pagemap *pgmap = NULL;
+
+       do {
+               struct page *page = pfn_to_page(pfn);
+
+               pgmap = get_dev_pagemap(pfn, pgmap);
+               if (unlikely(!pgmap)) {
+                       undo_dev_pagemap(nr, nr_start, pages);
+                       return 0;
+               }
+               SetPageReferenced(page);
+               pages[*nr] = page;
+               get_page(page);
+               put_dev_pagemap(pgmap);
+               (*nr)++;
+               pfn++;
+       } while (addr += PAGE_SIZE, addr != end);
+       return 1;
+}
+
+static int __gup_device_huge_pmd(pmd_t pmd, unsigned long addr,
+               unsigned long end, struct page **pages, int *nr)
+{
+       unsigned long fault_pfn;
+
+       fault_pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
+       return __gup_device_huge(fault_pfn, addr, end, pages, nr);
+}
+
+static int __gup_device_huge_pud(pud_t pud, unsigned long addr,
+               unsigned long end, struct page **pages, int *nr)
+{
+       unsigned long fault_pfn;
+
+       fault_pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
+       return __gup_device_huge(fault_pfn, addr, end, pages, nr);
+}
+#else
+static int __gup_device_huge_pmd(pmd_t pmd, unsigned long addr,
+               unsigned long end, struct page **pages, int *nr)
+{
+       BUILD_BUG();
+       return 0;
+}
+
+static int __gup_device_huge_pud(pud_t pud, unsigned long addr,
+               unsigned long end, struct page **pages, int *nr)
+{
+       BUILD_BUG();
+       return 0;
+}
+#endif
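
The fault_pfn arithmetic resolves addr to the 4 KB subpage inside the huge device mapping. A worked example, assuming x86-64 constants (PAGE_SHIFT = 12, a PMD mapping 2 MB) and made-up addresses:

	/*
	 * addr                             = 0x2005000, in a 2 MB devmap PMD
	 * pmd_pfn(pmd)                     = 0x2000 (huge page at phys 0x2000000)
	 * addr & ~PMD_MASK                 = 0x5000 (byte offset into huge page)
	 * (addr & ~PMD_MASK) >> PAGE_SHIFT = 5      (subpage index)
	 * fault_pfn                        = 0x2000 + 5 = 0x2005
	 */

__gup_device_huge() then takes a reference on each small page from that pfn onwards, because ZONE_DEVICE huge mappings of this era are not compound pages and must be refcounted per 4 KB page.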
+
 static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
                unsigned long end, int write, struct page **pages, int *nr)
 {
        struct page *head, *page;
        int refs;
 
-       if (write && !pmd_write(orig))
+       if (!pmd_access_permitted(orig, write))
                return 0;
 
+       if (pmd_devmap(orig))
+               return __gup_device_huge_pmd(orig, addr, end, pages, nr);
+
        refs = 0;
        head = pmd_page(orig);
        page = head + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
@@ -1284,6 +1380,7 @@ static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
                return 0;
        }
 
+       SetPageReferenced(head);
        return 1;
 }
 
@@ -1293,9 +1390,12 @@ static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
        struct page *head, *page;
        int refs;
 
-       if (write && !pud_write(orig))
+       if (!pud_access_permitted(orig, write))
                return 0;
 
+       if (pud_devmap(orig))
+               return __gup_device_huge_pud(orig, addr, end, pages, nr);
+
        refs = 0;
        head = pud_page(orig);
        page = head + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
@@ -1319,6 +1419,7 @@ static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
                return 0;
        }
 
+       SetPageReferenced(head);
        return 1;
 }
 
@@ -1329,9 +1430,10 @@ static int gup_huge_pgd(pgd_t orig, pgd_t *pgdp, unsigned long addr,
        int refs;
        struct page *head, *page;
 
-       if (write && !pgd_write(orig))
+       if (!pgd_access_permitted(orig, write))
                return 0;
 
+       BUILD_BUG_ON(pgd_devmap(orig));
        refs = 0;
        head = pgd_page(orig);
        page = head + ((addr & ~PGDIR_MASK) >> PAGE_SHIFT);
@@ -1355,6 +1457,7 @@ static int gup_huge_pgd(pgd_t orig, pgd_t *pgdp, unsigned long addr,
                return 0;
        }
 
+       SetPageReferenced(head);
        return 1;
 }
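
The SetPageReferenced() calls added to the PTE, PMD, PUD and PGD paths mirror what the x86 implementation's get_head_page_multiple() already did: marking a page referenced after a fast-GUP hit tells reclaim the page was recently accessed, so it is not treated as idle merely because it was reached through get_user_pages_fast() rather than a page fault.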
 
@@ -1400,13 +1503,13 @@ static int gup_pmd_range(pud_t pud, unsigned long addr, unsigned long end,
        return 1;
 }
 
-static int gup_pud_range(pgd_t pgd, unsigned long addr, unsigned long end,
+static int gup_pud_range(p4d_t p4d, unsigned long addr, unsigned long end,
                         int write, struct page **pages, int *nr)
 {
        unsigned long next;
        pud_t *pudp;
 
-       pudp = pud_offset(&pgd, addr);
+       pudp = pud_offset(&p4d, addr);
        do {
                pud_t pud = READ_ONCE(*pudp);
 
@@ -1428,6 +1531,31 @@ static int gup_pud_range(pgd_t pgd, unsigned long addr, unsigned long end,
        return 1;
 }
 
+static int gup_p4d_range(pgd_t pgd, unsigned long addr, unsigned long end,
+                        int write, struct page **pages, int *nr)
+{
+       unsigned long next;
+       p4d_t *p4dp;
+
+       p4dp = p4d_offset(&pgd, addr);
+       do {
+               p4d_t p4d = READ_ONCE(*p4dp);
+
+               next = p4d_addr_end(addr, end);
+               if (p4d_none(p4d))
+                       return 0;
+               BUILD_BUG_ON(p4d_huge(p4d));
+               if (unlikely(is_hugepd(__hugepd(p4d_val(p4d))))) {
+                       if (!gup_huge_pd(__hugepd(p4d_val(p4d)), addr,
+                                        P4D_SHIFT, next, write, pages, nr))
+                               return 0;
+               } else if (!gup_pud_range(p4d, addr, next, write, pages, nr))
+                       return 0;
+       } while (p4dp++, addr = next, addr != end);
+
+       return 1;
+}
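
On configurations with fewer than five page-table levels the new p4d step folds away and gup_p4d_range() degenerates to the old direct gup_pud_range() call. The folding, roughly as <asm-generic/pgtable-nop4d.h> does it:

	/* sketch of the folded case, approximating pgtable-nop4d.h */
	typedef struct { pgd_t pgd; } p4d_t;
	#define PTRS_PER_P4D	1

	static inline p4d_t *p4d_offset(pgd_t *pgd, unsigned long address)
	{
		return (p4d_t *)pgd;	/* the p4d entry wraps the pgd entry */
	}

With one p4d per pgd, p4d_addr_end(addr, end) clamps to end, so the loop body runs exactly once per pgd entry and the p4d-level checks act on the underlying top-level entry.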
+
 /*
  * Like get_user_pages_fast() except it's IRQ-safe in that it won't fall back to
  * the regular GUP. It will only return non-negative values.
@@ -1478,7 +1606,7 @@ int __get_user_pages_fast(unsigned long start, int nr_pages, int write,
                        if (!gup_huge_pd(__hugepd(pgd_val(pgd)), addr,
                                         PGDIR_SHIFT, next, write, pages, &nr))
                                break;
-               } else if (!gup_pud_range(pgd, addr, next, write, pages, &nr))
+               } else if (!gup_p4d_range(pgd, addr, next, write, pages, &nr))
                        break;
        } while (pgdp++, addr = next, addr != end);
        local_irq_restore(flags);