net/9p: switch the guts of p9_client_{read,write}() to iov_iter
authorAl Viro <viro@zeniv.linux.org.uk>
Wed, 1 Apr 2015 23:57:53 +0000 (19:57 -0400)
committerAl Viro <viro@zeniv.linux.org.uk>
Sun, 12 Apr 2015 02:28:25 +0000 (22:28 -0400)
... and have get_user_pages_fast() mapping fewer pages than requested
to generate a short read/write.

Signed-off-by: Al Viro <viro@zeniv.linux.org.uk>
include/net/9p/transport.h
net/9p/client.c
net/9p/protocol.c
net/9p/trans_virtio.c

index 2a25dec3021166d5aba52ad155e8ca01e0b1570e..5122b5e40f78f1aec1b30aa17d13dd68558d9f4f 100644 (file)
@@ -61,7 +61,7 @@ struct p9_trans_module {
        int (*cancel) (struct p9_client *, struct p9_req_t *req);
        int (*cancelled)(struct p9_client *, struct p9_req_t *req);
        int (*zc_request)(struct p9_client *, struct p9_req_t *,
-                         char *, char *, int , int, int, int);
+                         struct iov_iter *, struct iov_iter *, int , int, int);
 };
 
 void v9fs_register_trans(struct p9_trans_module *m);
index e86a9bea1d160ccc1a739eee576d0bdbf0483956..9ef5d85f082f5716651f66a8309ffd6680af8e9b 100644 (file)
@@ -34,6 +34,7 @@
 #include <linux/slab.h>
 #include <linux/sched.h>
 #include <linux/uaccess.h>
+#include <linux/uio.h>
 #include <net/9p/9p.h>
 #include <linux/parser.h>
 #include <net/9p/client.h>
@@ -555,7 +556,7 @@ out_err:
  */
 
 static int p9_check_zc_errors(struct p9_client *c, struct p9_req_t *req,
-                             char *uidata, int in_hdrlen, int kern_buf)
+                             struct iov_iter *uidata, int in_hdrlen)
 {
        int err;
        int ecode;
@@ -591,16 +592,11 @@ static int p9_check_zc_errors(struct p9_client *c, struct p9_req_t *req,
                ename = &req->rc->sdata[req->rc->offset];
                if (len > inline_len) {
                        /* We have error in external buffer */
-                       if (kern_buf) {
-                               memcpy(ename + inline_len, uidata,
-                                      len - inline_len);
-                       } else {
-                               err = copy_from_user(ename + inline_len,
-                                                    uidata, len - inline_len);
-                               if (err) {
-                                       err = -EFAULT;
-                                       goto out_err;
-                               }
+                       err = copy_from_iter(ename + inline_len,
+                                            len - inline_len, uidata);
+                       if (err != len - inline_len) {
+                               err = -EFAULT;
+                               goto out_err;
                        }
                }
                ename = NULL;
@@ -806,8 +802,8 @@ reterr:
  * p9_client_zc_rpc - issue a request and wait for a response
  * @c: client session
  * @type: type of request
- * @uidata: user bffer that should be ued for zero copy read
- * @uodata: user buffer that shoud be user for zero copy write
+ * @uidata: destination for zero copy read
+ * @uodata: source for zero copy write
  * @inlen: read buffer size
  * @olen: write buffer size
  * @hdrlen: reader header size, This is the size of response protocol data
@@ -816,9 +812,10 @@ reterr:
  * Returns request structure (which client must free using p9_free_req)
  */
 static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type,
-                                        char *uidata, char *uodata,
+                                        struct iov_iter *uidata,
+                                        struct iov_iter *uodata,
                                         int inlen, int olen, int in_hdrlen,
-                                        int kern_buf, const char *fmt, ...)
+                                        const char *fmt, ...)
 {
        va_list ap;
        int sigpending, err;
@@ -841,12 +838,8 @@ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type,
        } else
                sigpending = 0;
 
-       /* If we are called with KERNEL_DS force kern_buf */
-       if (segment_eq(get_fs(), KERNEL_DS))
-               kern_buf = 1;
-
        err = c->trans_mod->zc_request(c, req, uidata, uodata,
-                                      inlen, olen, in_hdrlen, kern_buf);
+                                      inlen, olen, in_hdrlen);
        if (err < 0) {
                if (err == -EIO)
                        c->status = Disconnected;
@@ -876,7 +869,7 @@ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type,
        if (err < 0)
                goto reterr;
 
-       err = p9_check_zc_errors(c, req, uidata, in_hdrlen, kern_buf);
+       err = p9_check_zc_errors(c, req, uidata, in_hdrlen);
        trace_9p_client_res(c, type, req->rc->tag, err);
        if (!err)
                return req;
@@ -1545,11 +1538,24 @@ p9_client_read(struct p9_fid *fid, char *data, char __user *udata, u64 offset,
                                                                u32 count)
 {
        char *dataptr;
-       int kernel_buf = 0;
        struct p9_req_t *req;
        struct p9_client *clnt;
        int err, rsize, non_zc = 0;
-
+       struct iov_iter to;
+       union {
+               struct kvec kv;
+               struct iovec iov;
+       } v;
+
+       if (data) {
+               v.kv.iov_base = data;
+               v.kv.iov_len = count;
+               iov_iter_kvec(&to, ITER_KVEC | READ, &v.kv, 1, count);
+       } else {
+               v.iov.iov_base = udata;
+               v.iov.iov_len = count;
+               iov_iter_init(&to, READ, &v.iov, 1, count);
+       }
 
        p9_debug(P9_DEBUG_9P, ">>> TREAD fid %d offset %llu %d\n",
                   fid->fid, (unsigned long long) offset, count);
@@ -1565,18 +1571,12 @@ p9_client_read(struct p9_fid *fid, char *data, char __user *udata, u64 offset,
 
        /* Don't bother zerocopy for small IO (< 1024) */
        if (clnt->trans_mod->zc_request && rsize > 1024) {
-               char *indata;
-               if (data) {
-                       kernel_buf = 1;
-                       indata = data;
-               } else
-                       indata = (__force char *)udata;
                /*
                 * response header len is 11
                 * PDU Header(7) + IO Size (4)
                 */
-               req = p9_client_zc_rpc(clnt, P9_TREAD, indata, NULL, rsize, 0,
-                                      11, kernel_buf, "dqd", fid->fid,
+               req = p9_client_zc_rpc(clnt, P9_TREAD, &to, NULL, rsize, 0,
+                                      11, "dqd", fid->fid,
                                       offset, rsize);
        } else {
                non_zc = 1;
@@ -1596,16 +1596,9 @@ p9_client_read(struct p9_fid *fid, char *data, char __user *udata, u64 offset,
 
        p9_debug(P9_DEBUG_9P, "<<< RREAD count %d\n", count);
 
-       if (non_zc) {
-               if (data) {
-                       memmove(data, dataptr, count);
-               } else {
-                       err = copy_to_user(udata, dataptr, count);
-                       if (err) {
-                               err = -EFAULT;
-                               goto free_and_error;
-                       }
-               }
+       if (non_zc && copy_to_iter(dataptr, count, &to) != count) {
+               err = -EFAULT;
+               goto free_and_error;
        }
        p9_free_req(clnt, req);
        return count;
@@ -1622,9 +1615,23 @@ p9_client_write(struct p9_fid *fid, char *data, const char __user *udata,
                                                        u64 offset, u32 count)
 {
        int err, rsize;
-       int kernel_buf = 0;
        struct p9_client *clnt;
        struct p9_req_t *req;
+       struct iov_iter from;
+       union {
+               struct kvec kv;
+               struct iovec iov;
+       } v;
+
+       if (data) {
+               v.kv.iov_base = data;
+               v.kv.iov_len = count;
+               iov_iter_kvec(&from, ITER_KVEC | WRITE, &v.kv, 1, count);
+       } else {
+               v.iov.iov_base = udata;
+               v.iov.iov_len = count;
+               iov_iter_init(&from, WRITE, &v.iov, 1, count);
+       }
 
        p9_debug(P9_DEBUG_9P, ">>> TWRITE fid %d offset %llu count %d\n",
                                fid->fid, (unsigned long long) offset, count);
@@ -1640,22 +1647,12 @@ p9_client_write(struct p9_fid *fid, char *data, const char __user *udata,
 
        /* Don't bother zerocopy for small IO (< 1024) */
        if (clnt->trans_mod->zc_request && rsize > 1024) {
-               char *odata;
-               if (data) {
-                       kernel_buf = 1;
-                       odata = data;
-               } else
-                       odata = (char *)udata;
-               req = p9_client_zc_rpc(clnt, P9_TWRITE, NULL, odata, 0, rsize,
-                                      P9_ZC_HDR_SZ, kernel_buf, "dqd",
+               req = p9_client_zc_rpc(clnt, P9_TWRITE, NULL, &from, 0, rsize,
+                                      P9_ZC_HDR_SZ, "dqd",
                                       fid->fid, offset, rsize);
        } else {
-               if (data)
-                       req = p9_client_rpc(clnt, P9_TWRITE, "dqD", fid->fid,
-                                           offset, rsize, data);
-               else
-                       req = p9_client_rpc(clnt, P9_TWRITE, "dqU", fid->fid,
-                                           offset, rsize, udata);
+               req = p9_client_rpc(clnt, P9_TWRITE, "dqV", fid->fid,
+                                           offset, rsize, &from);
        }
        if (IS_ERR(req)) {
                err = PTR_ERR(req);
@@ -2068,6 +2065,10 @@ int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset)
        struct p9_client *clnt;
        struct p9_req_t *req;
        char *dataptr;
+       struct kvec kv = {.iov_base = data, .iov_len = count};
+       struct iov_iter to;
+
+       iov_iter_kvec(&to, READ | ITER_KVEC, &kv, 1, count);
 
        p9_debug(P9_DEBUG_9P, ">>> TREADDIR fid %d offset %llu count %d\n",
                                fid->fid, (unsigned long long) offset, count);
@@ -2088,8 +2089,8 @@ int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset)
                 * response header len is 11
                 * PDU Header(7) + IO Size (4)
                 */
-               req = p9_client_zc_rpc(clnt, P9_TREADDIR, data, NULL, rsize, 0,
-                                      11, 1, "dqd", fid->fid, offset, rsize);
+               req = p9_client_zc_rpc(clnt, P9_TREADDIR, &to, NULL, rsize, 0,
+                                      11, "dqd", fid->fid, offset, rsize);
        } else {
                non_zc = 1;
                req = p9_client_rpc(clnt, P9_TREADDIR, "dqd", fid->fid,
index ab9127ec5b7a6881e7dd2116e49819186675562f..e9d0f0c1a04827f0d1cc0f554f6b4dfaabb2c414 100644 (file)
@@ -33,6 +33,7 @@
 #include <linux/sched.h>
 #include <linux/stddef.h>
 #include <linux/types.h>
+#include <linux/uio.h>
 #include <net/9p/9p.h>
 #include <net/9p/client.h>
 #include "protocol.h"
@@ -69,10 +70,11 @@ static size_t pdu_write(struct p9_fcall *pdu, const void *data, size_t size)
 }
 
 static size_t
-pdu_write_u(struct p9_fcall *pdu, const char __user *udata, size_t size)
+pdu_write_u(struct p9_fcall *pdu, struct iov_iter *from, size_t size)
 {
        size_t len = min(pdu->capacity - pdu->size, size);
-       if (copy_from_user(&pdu->sdata[pdu->size], udata, len))
+       struct iov_iter i = *from;
+       if (copy_from_iter(&pdu->sdata[pdu->size], len, &i) != len)
                len = 0;
 
        pdu->size += len;
@@ -437,23 +439,13 @@ p9pdu_vwritef(struct p9_fcall *pdu, int proto_version, const char *fmt,
                                                 stbuf->extension, stbuf->n_uid,
                                                 stbuf->n_gid, stbuf->n_muid);
                        } break;
-               case 'D':{
-                               uint32_t count = va_arg(ap, uint32_t);
-                               const void *data = va_arg(ap, const void *);
-
-                               errcode = p9pdu_writef(pdu, proto_version, "d",
-                                                                       count);
-                               if (!errcode && pdu_write(pdu, data, count))
-                                       errcode = -EFAULT;
-                       }
-                       break;
-               case 'U':{
+               case 'V':{
                                int32_t count = va_arg(ap, int32_t);
-                               const char __user *udata =
-                                               va_arg(ap, const void __user *);
+                               struct iov_iter *from =
+                                               va_arg(ap, struct iov_iter *);
                                errcode = p9pdu_writef(pdu, proto_version, "d",
                                                                        count);
-                               if (!errcode && pdu_write_u(pdu, udata, count))
+                               if (!errcode && pdu_write_u(pdu, from, count))
                                        errcode = -EFAULT;
                        }
                        break;
index 36a1a739ad68ff57eace5ba4bc4166faf12c485b..e62bcbbabb5e3cd43717f7980fce3e280aba3ded 100644 (file)
@@ -217,15 +217,15 @@ static int p9_virtio_cancel(struct p9_client *client, struct p9_req_t *req)
  * @start: which segment of the sg_list to start at
  * @pdata: a list of pages to add into sg.
  * @nr_pages: number of pages to pack into the scatter/gather list
- * @data: data to pack into scatter/gather list
+ * @offs: amount of data in the beginning of first page _not_ to pack
  * @count: amount of data to pack into the scatter/gather list
  */
 static int
 pack_sg_list_p(struct scatterlist *sg, int start, int limit,
-              struct page **pdata, int nr_pages, char *data, int count)
+              struct page **pdata, int nr_pages, size_t offs, int count)
 {
        int i = 0, s;
-       int data_off;
+       int data_off = offs;
        int index = start;
 
        BUG_ON(nr_pages > (limit - start));
@@ -233,16 +233,14 @@ pack_sg_list_p(struct scatterlist *sg, int start, int limit,
         * if the first page doesn't start at
         * page boundary find the offset
         */
-       data_off = offset_in_page(data);
        while (nr_pages) {
-               s = rest_of_page(data);
+               s = PAGE_SIZE - data_off;
                if (s > count)
                        s = count;
                /* Make sure we don't terminate early. */
                sg_unmark_end(&sg[index]);
                sg_set_page(&sg[index++], pdata[i++], s, data_off);
                data_off = 0;
-               data += s;
                count -= s;
                nr_pages--;
        }
@@ -314,11 +312,20 @@ req_retry:
 }
 
 static int p9_get_mapped_pages(struct virtio_chan *chan,
-                              struct page **pages, char *data,
-                              int nr_pages, int write, int kern_buf)
+                              struct page ***pages,
+                              struct iov_iter *data,
+                              int count,
+                              size_t *offs,
+                              int *need_drop)
 {
+       int nr_pages;
        int err;
-       if (!kern_buf) {
+
+       if (!iov_iter_count(data))
+               return 0;
+
+       if (!(data->type & ITER_KVEC)) {
+               int n;
                /*
                 * We allow only p9_max_pages pinned. We wait for the
                 * Other zc request to finish here
@@ -329,26 +336,49 @@ static int p9_get_mapped_pages(struct virtio_chan *chan,
                        if (err == -ERESTARTSYS)
                                return err;
                }
-               err = p9_payload_gup(data, &nr_pages, pages, write);
-               if (err < 0)
-                       return err;
+               n = iov_iter_get_pages_alloc(data, pages, count, offs);
+               if (n < 0)
+                       return n;
+               *need_drop = 1;
+               nr_pages = DIV_ROUND_UP(n + *offs, PAGE_SIZE);
                atomic_add(nr_pages, &vp_pinned);
+               return n;
        } else {
                /* kernel buffer, no need to pin pages */
-               int s, index = 0;
-               int count = nr_pages;
-               while (nr_pages) {
-                       s = rest_of_page(data);
-                       if (is_vmalloc_addr(data))
-                               pages[index++] = vmalloc_to_page(data);
+               int index;
+               size_t len;
+               void *p;
+
+               /* we'd already checked that it's non-empty */
+               while (1) {
+                       len = iov_iter_single_seg_count(data);
+                       if (likely(len)) {
+                               p = data->kvec->iov_base + data->iov_offset;
+                               break;
+                       }
+                       iov_iter_advance(data, 0);
+               }
+               if (len > count)
+                       len = count;
+
+               nr_pages = DIV_ROUND_UP((unsigned long)p + len, PAGE_SIZE) -
+                          (unsigned long)p / PAGE_SIZE;
+
+               *pages = kmalloc(sizeof(struct page *) * nr_pages, GFP_NOFS);
+               if (!*pages)
+                       return -ENOMEM;
+
+               *need_drop = 0;
+               p -= (*offs = (unsigned long)p % PAGE_SIZE);
+               for (index = 0; index < nr_pages; index++) {
+                       if (is_vmalloc_addr(p))
+                               (*pages)[index] = vmalloc_to_page(p);
                        else
-                               pages[index++] = kmap_to_page(data);
-                       data += s;
-                       nr_pages--;
+                               (*pages)[index] = kmap_to_page(p);
+                       p += PAGE_SIZE;
                }
-               nr_pages = count;
+               return len;
        }
-       return nr_pages;
 }
 
 /**
@@ -364,8 +394,8 @@ static int p9_get_mapped_pages(struct virtio_chan *chan,
  */
 static int
 p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req,
-                    char *uidata, char *uodata, int inlen,
-                    int outlen, int in_hdr_len, int kern_buf)
+                    struct iov_iter *uidata, struct iov_iter *uodata,
+                    int inlen, int outlen, int in_hdr_len)
 {
        int in, out, err, out_sgs, in_sgs;
        unsigned long flags;
@@ -373,41 +403,32 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req,
        struct page **in_pages = NULL, **out_pages = NULL;
        struct virtio_chan *chan = client->trans;
        struct scatterlist *sgs[4];
+       size_t offs;
+       int need_drop = 0;
 
        p9_debug(P9_DEBUG_TRANS, "virtio request\n");
 
        if (uodata) {
-               out_nr_pages = p9_nr_pages(uodata, outlen);
-               out_pages = kmalloc(sizeof(struct page *) * out_nr_pages,
-                                   GFP_NOFS);
-               if (!out_pages) {
-                       err = -ENOMEM;
-                       goto err_out;
-               }
-               out_nr_pages = p9_get_mapped_pages(chan, out_pages, uodata,
-                                                  out_nr_pages, 0, kern_buf);
-               if (out_nr_pages < 0) {
-                       err = out_nr_pages;
-                       kfree(out_pages);
-                       out_pages = NULL;
-                       goto err_out;
+               int n = p9_get_mapped_pages(chan, &out_pages, uodata,
+                                           outlen, &offs, &need_drop);
+               if (n < 0)
+                       return n;
+               out_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE);
+               if (n != outlen) {
+                       __le32 v = cpu_to_le32(n);
+                       memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4);
+                       outlen = n;
                }
-       }
-       if (uidata) {
-               in_nr_pages = p9_nr_pages(uidata, inlen);
-               in_pages = kmalloc(sizeof(struct page *) * in_nr_pages,
-                                  GFP_NOFS);
-               if (!in_pages) {
-                       err = -ENOMEM;
-                       goto err_out;
-               }
-               in_nr_pages = p9_get_mapped_pages(chan, in_pages, uidata,
-                                                 in_nr_pages, 1, kern_buf);
-               if (in_nr_pages < 0) {
-                       err = in_nr_pages;
-                       kfree(in_pages);
-                       in_pages = NULL;
-                       goto err_out;
+       } else if (uidata) {
+               int n = p9_get_mapped_pages(chan, &in_pages, uidata,
+                                           inlen, &offs, &need_drop);
+               if (n < 0)
+                       return n;
+               in_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE);
+               if (n != inlen) {
+                       __le32 v = cpu_to_le32(n);
+                       memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4);
+                       inlen = n;
                }
        }
        req->status = REQ_STATUS_SENT;
@@ -426,7 +447,7 @@ req_retry_pinned:
        if (out_pages) {
                sgs[out_sgs++] = chan->sg + out;
                out += pack_sg_list_p(chan->sg, out, VIRTQUEUE_NUM,
-                                     out_pages, out_nr_pages, uodata, outlen);
+                                     out_pages, out_nr_pages, offs, outlen);
        }
                
        /*
@@ -444,7 +465,7 @@ req_retry_pinned:
        if (in_pages) {
                sgs[out_sgs + in_sgs++] = chan->sg + out + in;
                in += pack_sg_list_p(chan->sg, out + in, VIRTQUEUE_NUM,
-                                    in_pages, in_nr_pages, uidata, inlen);
+                                    in_pages, in_nr_pages, offs, inlen);
        }
 
        BUG_ON(out_sgs + in_sgs > ARRAY_SIZE(sgs));
@@ -478,7 +499,7 @@ req_retry_pinned:
         * Non kernel buffers are pinned, unpin them
         */
 err_out:
-       if (!kern_buf) {
+       if (need_drop) {
                if (in_pages) {
                        p9_release_pages(in_pages, in_nr_pages);
                        atomic_sub(in_nr_pages, &vp_pinned);