shmctl: split the work from copyin/copyout
authorAl Viro <viro@zeniv.linux.org.uk>
Sun, 9 Jul 2017 00:58:06 +0000 (20:58 -0400)
committerAl Viro <viro@zeniv.linux.org.uk>
Sun, 16 Jul 2017 00:46:41 +0000 (20:46 -0400)
Signed-off-by: Al Viro <viro@zeniv.linux.org.uk>
ipc/shm.c

index 28a444861a8f489fa46edca81878eedf0f55da7a..b4073c08d0e82c11856283bd3c825aee65fa6de4 100644 (file)
--- a/ipc/shm.c
+++ b/ipc/shm.c
@@ -813,23 +813,17 @@ static void shm_get_stat(struct ipc_namespace *ns, unsigned long *rss,
  * NOTE: no locks must be held, the rwsem is taken inside this function.
  */
 static int shmctl_down(struct ipc_namespace *ns, int shmid, int cmd,
-                      struct shmid_ds __user *buf, int version)
+                      struct shmid64_ds *shmid64)
 {
        struct kern_ipc_perm *ipcp;
-       struct shmid64_ds shmid64;
        struct shmid_kernel *shp;
        int err;
 
-       if (cmd == IPC_SET) {
-               if (copy_shmid_from_user(&shmid64, buf, version))
-                       return -EFAULT;
-       }
-
        down_write(&shm_ids(ns).rwsem);
        rcu_read_lock();
 
        ipcp = ipcctl_pre_down_nolock(ns, &shm_ids(ns), shmid, cmd,
-                                     &shmid64.shm_perm, 0);
+                                     &shmid64->shm_perm, 0);
        if (IS_ERR(ipcp)) {
                err = PTR_ERR(ipcp);
                goto out_unlock1;
@@ -849,7 +843,7 @@ static int shmctl_down(struct ipc_namespace *ns, int shmid, int cmd,
                goto out_up;
        case IPC_SET:
                ipc_lock_object(&shp->shm_perm);
-               err = ipc_update_perm(&shmid64.shm_perm, ipcp);
+               err = ipc_update_perm(&shmid64->shm_perm, ipcp);
                if (err)
                        goto out_unlock0;
                shp->shm_ctim = get_seconds();
@@ -868,212 +862,162 @@ out_up:
        return err;
 }
 
-static int shmctl_nolock(struct ipc_namespace *ns, int shmid,
-                        int cmd, int version, void __user *buf)
+static int shmctl_ipc_info(struct ipc_namespace *ns,
+                          struct shminfo64 *shminfo)
 {
-       int err;
-       struct shmid_kernel *shp;
-
-       /* preliminary security checks for *_INFO */
-       if (cmd == IPC_INFO || cmd == SHM_INFO) {
-               err = security_shm_shmctl(NULL, cmd);
-               if (err)
-                       return err;
-       }
-
-       switch (cmd) {
-       case IPC_INFO:
-       {
-               struct shminfo64 shminfo;
-
-               memset(&shminfo, 0, sizeof(shminfo));
-               shminfo.shmmni = shminfo.shmseg = ns->shm_ctlmni;
-               shminfo.shmmax = ns->shm_ctlmax;
-               shminfo.shmall = ns->shm_ctlall;
-
-               shminfo.shmmin = SHMMIN;
-               if (copy_shminfo_to_user(buf, &shminfo, version))
-                       return -EFAULT;
-
+       int err = security_shm_shmctl(NULL, IPC_INFO);
+       if (!err) {
+               memset(shminfo, 0, sizeof(*shminfo));
+               shminfo->shmmni = shminfo->shmseg = ns->shm_ctlmni;
+               shminfo->shmmax = ns->shm_ctlmax;
+               shminfo->shmall = ns->shm_ctlall;
+               shminfo->shmmin = SHMMIN;
                down_read(&shm_ids(ns).rwsem);
                err = ipc_get_maxid(&shm_ids(ns));
                up_read(&shm_ids(ns).rwsem);
-
                if (err < 0)
                        err = 0;
-               goto out;
        }
-       case SHM_INFO:
-       {
-               struct shm_info shm_info;
+       return err;
+}
 
-               memset(&shm_info, 0, sizeof(shm_info));
+static int shmctl_shm_info(struct ipc_namespace *ns,
+                          struct shm_info *shm_info)
+{
+       int err = security_shm_shmctl(NULL, SHM_INFO);
+       if (!err) {
+               memset(shm_info, 0, sizeof(*shm_info));
                down_read(&shm_ids(ns).rwsem);
-               shm_info.used_ids = shm_ids(ns).in_use;
-               shm_get_stat(ns, &shm_info.shm_rss, &shm_info.shm_swp);
-               shm_info.shm_tot = ns->shm_tot;
-               shm_info.swap_attempts = 0;
-               shm_info.swap_successes = 0;
+               shm_info->used_ids = shm_ids(ns).in_use;
+               shm_get_stat(ns, &shm_info->shm_rss, &shm_info->shm_swp);
+               shm_info->shm_tot = ns->shm_tot;
+               shm_info->swap_attempts = 0;
+               shm_info->swap_successes = 0;
                err = ipc_get_maxid(&shm_ids(ns));
                up_read(&shm_ids(ns).rwsem);
-               if (copy_to_user(buf, &shm_info, sizeof(shm_info))) {
-                       err = -EFAULT;
-                       goto out;
-               }
-
-               err = err < 0 ? 0 : err;
-               goto out;
+               if (err < 0)
+                       err = 0;
        }
-       case SHM_STAT:
-       case IPC_STAT:
-       {
-               struct shmid64_ds tbuf;
-               int result;
-
-               rcu_read_lock();
-               if (cmd == SHM_STAT) {
-                       shp = shm_obtain_object(ns, shmid);
-                       if (IS_ERR(shp)) {
-                               err = PTR_ERR(shp);
-                               goto out_unlock;
-                       }
-                       result = shp->shm_perm.id;
-               } else {
-                       shp = shm_obtain_object_check(ns, shmid);
-                       if (IS_ERR(shp)) {
-                               err = PTR_ERR(shp);
-                               goto out_unlock;
-                       }
-                       result = 0;
-               }
+       return err;
+}
 
-               err = -EACCES;
-               if (ipcperms(ns, &shp->shm_perm, S_IRUGO))
-                       goto out_unlock;
+static int shmctl_stat(struct ipc_namespace *ns, int shmid,
+                       int cmd, struct shmid64_ds *tbuf)
+{
+       struct shmid_kernel *shp;
+       int result;
+       int err;
 
-               err = security_shm_shmctl(shp, cmd);
-               if (err)
+       rcu_read_lock();
+       if (cmd == SHM_STAT) {
+               shp = shm_obtain_object(ns, shmid);
+               if (IS_ERR(shp)) {
+                       err = PTR_ERR(shp);
                        goto out_unlock;
+               }
+               result = shp->shm_perm.id;
+       } else {
+               shp = shm_obtain_object_check(ns, shmid);
+               if (IS_ERR(shp)) {
+                       err = PTR_ERR(shp);
+                       goto out_unlock;
+               }
+               result = 0;
+       }
 
-               memset(&tbuf, 0, sizeof(tbuf));
-               kernel_to_ipc64_perm(&shp->shm_perm, &tbuf.shm_perm);
-               tbuf.shm_segsz  = shp->shm_segsz;
-               tbuf.shm_atime  = shp->shm_atim;
-               tbuf.shm_dtime  = shp->shm_dtim;
-               tbuf.shm_ctime  = shp->shm_ctim;
-               tbuf.shm_cpid   = shp->shm_cprid;
-               tbuf.shm_lpid   = shp->shm_lprid;
-               tbuf.shm_nattch = shp->shm_nattch;
-               rcu_read_unlock();
+       err = -EACCES;
+       if (ipcperms(ns, &shp->shm_perm, S_IRUGO))
+               goto out_unlock;
 
-               if (copy_shmid_to_user(buf, &tbuf, version))
-                       err = -EFAULT;
-               else
-                       err = result;
-               goto out;
-       }
-       default:
-               return -EINVAL;
-       }
+       err = security_shm_shmctl(shp, cmd);
+       if (err)
+               goto out_unlock;
+
+       memset(tbuf, 0, sizeof(*tbuf));
+       kernel_to_ipc64_perm(&shp->shm_perm, &tbuf->shm_perm);
+       tbuf->shm_segsz = shp->shm_segsz;
+       tbuf->shm_atime = shp->shm_atim;
+       tbuf->shm_dtime = shp->shm_dtim;
+       tbuf->shm_ctime = shp->shm_ctim;
+       tbuf->shm_cpid  = shp->shm_cprid;
+       tbuf->shm_lpid  = shp->shm_lprid;
+       tbuf->shm_nattch = shp->shm_nattch;
+       rcu_read_unlock();
+       return result;
 
 out_unlock:
        rcu_read_unlock();
-out:
        return err;
 }
 
-SYSCALL_DEFINE3(shmctl, int, shmid, int, cmd, struct shmid_ds __user *, buf)
+static int shmctl_do_lock(struct ipc_namespace *ns, int shmid, int cmd)
 {
        struct shmid_kernel *shp;
-       int err, version;
-       struct ipc_namespace *ns;
-
-       if (cmd < 0 || shmid < 0)
-               return -EINVAL;
+       struct file *shm_file;
+       int err;
 
-       version = ipc_parse_version(&cmd);
-       ns = current->nsproxy->ipc_ns;
+       rcu_read_lock();
+       shp = shm_obtain_object_check(ns, shmid);
+       if (IS_ERR(shp)) {
+               err = PTR_ERR(shp);
+               goto out_unlock1;
+       }
 
-       switch (cmd) {
-       case IPC_INFO:
-       case SHM_INFO:
-       case SHM_STAT:
-       case IPC_STAT:
-               return shmctl_nolock(ns, shmid, cmd, version, buf);
-       case IPC_RMID:
-       case IPC_SET:
-               return shmctl_down(ns, shmid, cmd, buf, version);
-       case SHM_LOCK:
-       case SHM_UNLOCK:
-       {
-               struct file *shm_file;
+       audit_ipc_obj(&(shp->shm_perm));
+       err = security_shm_shmctl(shp, cmd);
+       if (err)
+               goto out_unlock1;
 
-               rcu_read_lock();
-               shp = shm_obtain_object_check(ns, shmid);
-               if (IS_ERR(shp)) {
-                       err = PTR_ERR(shp);
-                       goto out_unlock1;
-               }
+       ipc_lock_object(&shp->shm_perm);
 
-               audit_ipc_obj(&(shp->shm_perm));
-               err = security_shm_shmctl(shp, cmd);
-               if (err)
-                       goto out_unlock1;
+       /* check if shm_destroy() is tearing down shp */
+       if (!ipc_valid_object(&shp->shm_perm)) {
+               err = -EIDRM;
+               goto out_unlock0;
+       }
 
-               ipc_lock_object(&shp->shm_perm);
+       if (!ns_capable(ns->user_ns, CAP_IPC_LOCK)) {
+               kuid_t euid = current_euid();
 
-               /* check if shm_destroy() is tearing down shp */
-               if (!ipc_valid_object(&shp->shm_perm)) {
-                       err = -EIDRM;
+               if (!uid_eq(euid, shp->shm_perm.uid) &&
+                   !uid_eq(euid, shp->shm_perm.cuid)) {
+                       err = -EPERM;
                        goto out_unlock0;
                }
-
-               if (!ns_capable(ns->user_ns, CAP_IPC_LOCK)) {
-                       kuid_t euid = current_euid();
-
-                       if (!uid_eq(euid, shp->shm_perm.uid) &&
-                           !uid_eq(euid, shp->shm_perm.cuid)) {
-                               err = -EPERM;
-                               goto out_unlock0;
-                       }
-                       if (cmd == SHM_LOCK && !rlimit(RLIMIT_MEMLOCK)) {
-                               err = -EPERM;
-                               goto out_unlock0;
-                       }
+               if (cmd == SHM_LOCK && !rlimit(RLIMIT_MEMLOCK)) {
+                       err = -EPERM;
+                       goto out_unlock0;
                }
+       }
 
-               shm_file = shp->shm_file;
-               if (is_file_hugepages(shm_file))
-                       goto out_unlock0;
+       shm_file = shp->shm_file;
+       if (is_file_hugepages(shm_file))
+               goto out_unlock0;
 
-               if (cmd == SHM_LOCK) {
-                       struct user_struct *user = current_user();
+       if (cmd == SHM_LOCK) {
+               struct user_struct *user = current_user();
 
-                       err = shmem_lock(shm_file, 1, user);
-                       if (!err && !(shp->shm_perm.mode & SHM_LOCKED)) {
-                               shp->shm_perm.mode |= SHM_LOCKED;
-                               shp->mlock_user = user;
-                       }
-                       goto out_unlock0;
+               err = shmem_lock(shm_file, 1, user);
+               if (!err && !(shp->shm_perm.mode & SHM_LOCKED)) {
+                       shp->shm_perm.mode |= SHM_LOCKED;
+                       shp->mlock_user = user;
                }
+               goto out_unlock0;
+       }
 
-               /* SHM_UNLOCK */
-               if (!(shp->shm_perm.mode & SHM_LOCKED))
-                       goto out_unlock0;
-               shmem_lock(shm_file, 0, shp->mlock_user);
-               shp->shm_perm.mode &= ~SHM_LOCKED;
-               shp->mlock_user = NULL;
-               get_file(shm_file);
-               ipc_unlock_object(&shp->shm_perm);
-               rcu_read_unlock();
-               shmem_unlock_mapping(shm_file->f_mapping);
+       /* SHM_UNLOCK */
+       if (!(shp->shm_perm.mode & SHM_LOCKED))
+               goto out_unlock0;
+       shmem_lock(shm_file, 0, shp->mlock_user);
+       shp->shm_perm.mode &= ~SHM_LOCKED;
+       shp->mlock_user = NULL;
+       get_file(shm_file);
+       ipc_unlock_object(&shp->shm_perm);
+       rcu_read_unlock();
+       shmem_unlock_mapping(shm_file->f_mapping);
 
-               fput(shm_file);
-               return err;
-       }
-       default:
-               return -EINVAL;
-       }
+       fput(shm_file);
+       return err;
 
 out_unlock0:
        ipc_unlock_object(&shp->shm_perm);
@@ -1082,6 +1026,59 @@ out_unlock1:
        return err;
 }
 
+SYSCALL_DEFINE3(shmctl, int, shmid, int, cmd, struct shmid_ds __user *, buf)
+{
+       int err, version;
+       struct ipc_namespace *ns;
+       struct shmid64_ds tbuf;
+
+       if (cmd < 0 || shmid < 0)
+               return -EINVAL;
+
+       version = ipc_parse_version(&cmd);
+       ns = current->nsproxy->ipc_ns;
+
+       switch (cmd) {
+       case IPC_INFO: {
+               struct shminfo64 shminfo;
+               err = shmctl_ipc_info(ns, &shminfo);
+               if (err < 0)
+                       return err;
+               if (copy_shminfo_to_user(buf, &shminfo, version))
+                       err = -EFAULT;
+               return err;
+       }
+       case SHM_INFO: {
+               struct shm_info shm_info;
+               err = shmctl_shm_info(ns, &shm_info);
+               if (err < 0)
+                       return err;
+               if (copy_to_user(buf, &shm_info, sizeof(shm_info)))
+                       err = -EFAULT;
+               return err;
+       }
+       case SHM_STAT:
+       case IPC_STAT: {
+               err = shmctl_stat(ns, shmid, cmd, &tbuf);
+               if (err < 0)
+                       return err;
+               if (copy_shmid_to_user(buf, &tbuf, version))
+                       err = -EFAULT;
+               return err;
+       }
+       case IPC_SET:
+               if (copy_shmid_from_user(&tbuf, buf, version))
+                       return -EFAULT;
+       case IPC_RMID:
+               return shmctl_down(ns, shmid, cmd, &tbuf);
+       case SHM_LOCK:
+       case SHM_UNLOCK:
+               return shmctl_do_lock(ns, shmid, cmd);
+       default:
+               return -EINVAL;
+       }
+}
+
 /*
  * Fix shmaddr, allocate descriptor, map shm, add attach descriptor to lists.
  *