Input: tsdev - implement proper locking
authorDmitry Torokhov <dmitry.torokhov@gmail.com>
Thu, 30 Aug 2007 04:22:39 +0000 (00:22 -0400)
committerDmitry Torokhov <dmitry.torokhov@gmail.com>
Thu, 30 Aug 2007 04:22:39 +0000 (00:22 -0400)
Signed-off-by: Dmitry Torokhov <dtor@mail.ru>
drivers/input/tsdev.c

index d2f882e98e5e21c0c9c23b5d8c3e072883de052a..c189f1dd569f42f3451eae6e8bda519f47fc7fed 100644 (file)
@@ -112,6 +112,8 @@ struct tsdev {
        struct input_handle handle;
        wait_queue_head_t wait;
        struct list_head client_list;
+       spinlock_t client_lock; /* protects client_list */
+       struct mutex mutex;
        struct device dev;
 
        int x, y, pressure;
@@ -122,8 +124,9 @@ struct tsdev_client {
        struct fasync_struct *fasync;
        struct list_head node;
        struct tsdev *tsdev;
+       struct ts_event buffer[TSDEV_BUFFER_SIZE];
        int head, tail;
-       struct ts_event event[TSDEV_BUFFER_SIZE];
+       spinlock_t buffer_lock; /* protects access to buffer, head and tail */
        int raw;
 };
 
@@ -137,6 +140,7 @@ struct tsdev_client {
 #define TS_SET_CAL     _IOW(IOC_H3600_TS_MAGIC, 11, struct ts_calibration)
 
 static struct tsdev *tsdev_table[TSDEV_MINORS/2];
+static DEFINE_MUTEX(tsdev_table_mutex);
 
 static int tsdev_fasync(int fd, struct file *file, int on)
 {
@@ -144,9 +148,91 @@ static int tsdev_fasync(int fd, struct file *file, int on)
        int retval;
 
        retval = fasync_helper(fd, file, on, &client->fasync);
+
        return retval < 0 ? retval : 0;
 }
 
+static void tsdev_free(struct device *dev)
+{
+       struct tsdev *tsdev = container_of(dev, struct tsdev, dev);
+
+       kfree(tsdev);
+}
+
+static void tsdev_attach_client(struct tsdev *tsdev, struct tsdev_client *client)
+{
+       spin_lock(&tsdev->client_lock);
+       list_add_tail_rcu(&client->node, &tsdev->client_list);
+       spin_unlock(&tsdev->client_lock);
+       synchronize_sched();
+}
+
+static void tsdev_detach_client(struct tsdev *tsdev, struct tsdev_client *client)
+{
+       spin_lock(&tsdev->client_lock);
+       list_del_rcu(&client->node);
+       spin_unlock(&tsdev->client_lock);
+       synchronize_sched();
+}
+
+static int tsdev_open_device(struct tsdev *tsdev)
+{
+       int retval;
+
+       retval = mutex_lock_interruptible(&tsdev->mutex);
+       if (retval)
+               return retval;
+
+       if (!tsdev->exist)
+               retval = -ENODEV;
+       else if (!tsdev->open++)
+               retval = input_open_device(&tsdev->handle);
+
+       mutex_unlock(&tsdev->mutex);
+       return retval;
+}
+
+static void tsdev_close_device(struct tsdev *tsdev)
+{
+       mutex_lock(&tsdev->mutex);
+
+       if (tsdev->exist && !--tsdev->open)
+               input_close_device(&tsdev->handle);
+
+       mutex_unlock(&tsdev->mutex);
+}
+
+/*
+ * Wake up users waiting for IO so they can disconnect from
+ * dead device.
+ */
+static void tsdev_hangup(struct tsdev *tsdev)
+{
+       struct tsdev_client *client;
+
+       spin_lock(&tsdev->client_lock);
+       list_for_each_entry(client, &tsdev->client_list, node)
+               kill_fasync(&client->fasync, SIGIO, POLL_HUP);
+       spin_unlock(&tsdev->client_lock);
+
+       wake_up_interruptible(&tsdev->wait);
+}
+
+static int tsdev_release(struct inode *inode, struct file *file)
+{
+       struct tsdev_client *client = file->private_data;
+       struct tsdev *tsdev = client->tsdev;
+
+       tsdev_fasync(-1, file, 0);
+       tsdev_detach_client(tsdev, client);
+       kfree(client);
+
+       tsdev_close_device(tsdev);
+       put_device(&tsdev->dev);
+
+       return 0;
+}
+
 static int tsdev_open(struct inode *inode, struct file *file)
 {
        int i = iminor(inode) - TSDEV_MINOR_BASE;
@@ -161,11 +247,16 @@ static int tsdev_open(struct inode *inode, struct file *file)
        if (i >= TSDEV_MINORS)
                return -ENODEV;
 
+       error = mutex_lock_interruptible(&tsdev_table_mutex);
+       if (error)
+               return error;
        tsdev = tsdev_table[i & TSDEV_MINOR_MASK];
-       if (!tsdev || !tsdev->exist)
-               return -ENODEV;
+       if (tsdev)
+               get_device(&tsdev->dev);
+       mutex_unlock(&tsdev_table_mutex);
 
-       get_device(&tsdev->dev);
+       if (!tsdev)
+               return -ENODEV;
 
        client = kzalloc(sizeof(struct tsdev_client), GFP_KERNEL);
        if (!client) {
@@ -173,51 +264,42 @@ static int tsdev_open(struct inode *inode, struct file *file)
                goto err_put_tsdev;
        }
 
+       spin_lock_init(&client->buffer_lock);
        client->tsdev = tsdev;
-       client->raw = (i >= TSDEV_MINORS / 2) ? 1 : 0;
-       list_add_tail(&client->node, &tsdev->client_list);
+       client->raw = i >= TSDEV_MINORS / 2;
+       tsdev_attach_client(tsdev, client);
 
-       if (!tsdev->open++ && tsdev->exist) {
-               error = input_open_device(&tsdev->handle);
-               if (error)
-                       goto err_free_client;
-       }
+       error = tsdev_open_device(tsdev);
+       if (error)
+               goto err_free_client;
 
        file->private_data = client;
        return 0;
 
  err_free_client:
-       list_del(&client->node);
+       tsdev_detach_client(tsdev, client);
        kfree(client);
  err_put_tsdev:
        put_device(&tsdev->dev);
        return error;
 }
 
-static void tsdev_free(struct device *dev)
-{
-       struct tsdev *tsdev = container_of(dev, struct tsdev, dev);
-
-       tsdev_table[tsdev->minor] = NULL;
-       kfree(tsdev);
-}
-
-static int tsdev_release(struct inode *inode, struct file *file)
+static int tsdev_fetch_next_event(struct tsdev_client *client,
+                                 struct ts_event *event)
 {
-       struct tsdev_client *client = file->private_data;
-       struct tsdev *tsdev = client->tsdev;
+       int have_event;
 
-       tsdev_fasync(-1, file, 0);
-
-       list_del(&client->node);
-       kfree(client);
+       spin_lock_irq(&client->buffer_lock);
 
-       if (!--tsdev->open && tsdev->exist)
-               input_close_device(&tsdev->handle);
+       have_event = client->head != client->tail;
+       if (have_event) {
+               *event = client->buffer[client->tail++];
+               client->tail &= TSDEV_BUFFER_SIZE - 1;
+       }
 
-       put_device(&tsdev->dev);
+       spin_unlock_irq(&client->buffer_lock);
 
-       return 0;
+       return have_event;
 }
 
 static ssize_t tsdev_read(struct file *file, char __user *buffer, size_t count,
@@ -225,9 +307,11 @@ static ssize_t tsdev_read(struct file *file, char __user *buffer, size_t count,
 {
        struct tsdev_client *client = file->private_data;
        struct tsdev *tsdev = client->tsdev;
-       int retval = 0;
+       struct ts_event event;
+       int retval;
 
-       if (client->head == client->tail && tsdev->exist && (file->f_flags & O_NONBLOCK))
+       if (client->head == client->tail && tsdev->exist &&
+           (file->f_flags & O_NONBLOCK))
                return -EAGAIN;
 
        retval = wait_event_interruptible(tsdev->wait,
@@ -238,13 +322,14 @@ static ssize_t tsdev_read(struct file *file, char __user *buffer, size_t count,
        if (!tsdev->exist)
                return -ENODEV;
 
-       while (client->head != client->tail &&
-              retval + sizeof (struct ts_event) <= count) {
-               if (copy_to_user (buffer + retval, client->event + client->tail,
-                                 sizeof (struct ts_event)))
+       while (retval + sizeof(struct ts_event) <= count &&
+              tsdev_fetch_next_event(client, &event)) {
+
+               if (copy_to_user(buffer + retval, &event,
+                                sizeof(struct ts_event)))
                        return -EFAULT;
-               client->tail = (client->tail + 1) & (TSDEV_BUFFER_SIZE - 1);
-               retval += sizeof (struct ts_event);
+
+               retval += sizeof(struct ts_event);
        }
 
        return retval;
@@ -261,14 +346,23 @@ static unsigned int tsdev_poll(struct file *file, poll_table *wait)
                (tsdev->exist ? 0 : (POLLHUP | POLLERR));
 }
 
-static int tsdev_ioctl(struct inode *inode, struct file *file,
-                      unsigned int cmd, unsigned long arg)
+static long tsdev_ioctl(struct file *file, unsigned int cmd, unsigned long arg)
 {
        struct tsdev_client *client = file->private_data;
        struct tsdev *tsdev = client->tsdev;
        int retval = 0;
 
+       retval = mutex_lock_interruptible(&tsdev->mutex);
+       if (retval)
+               return retval;
+
+       if (!tsdev->exist) {
+               retval = -ENODEV;
+               goto out;
+       }
+
        switch (cmd) {
+
        case TS_GET_CAL:
                if (copy_to_user((void __user *)arg, &tsdev->cal,
                                 sizeof (struct ts_calibration)))
@@ -277,7 +371,7 @@ static int tsdev_ioctl(struct inode *inode, struct file *file,
 
        case TS_SET_CAL:
                if (copy_from_user(&tsdev->cal, (void __user *)arg,
-                                  sizeof (struct ts_calibration)))
+                                  sizeof(struct ts_calibration)))
                        retval = -EFAULT;
                break;
 
@@ -286,29 +380,79 @@ static int tsdev_ioctl(struct inode *inode, struct file *file,
                break;
        }
 
+ out:
+       mutex_unlock(&tsdev->mutex);
        return retval;
 }
 
 static const struct file_operations tsdev_fops = {
-       .owner =        THIS_MODULE,
-       .open =         tsdev_open,
-       .release =      tsdev_release,
-       .read =         tsdev_read,
-       .poll =         tsdev_poll,
-       .fasync =       tsdev_fasync,
-       .ioctl =        tsdev_ioctl,
+       .owner          = THIS_MODULE,
+       .open           = tsdev_open,
+       .release        = tsdev_release,
+       .read           = tsdev_read,
+       .poll           = tsdev_poll,
+       .fasync         = tsdev_fasync,
+       .unlocked_ioctl = tsdev_ioctl,
 };
 
+static void tsdev_pass_event(struct tsdev *tsdev, struct tsdev_client *client,
+                            int x, int y, int pressure, int millisecs)
+{
+       struct ts_event *event;
+       int tmp;
+
+       /* Interrupts are already disabled, just acquire the lock */
+       spin_lock(&client->buffer_lock);
+
+       event = &client->buffer[client->head++];
+       client->head &= TSDEV_BUFFER_SIZE - 1;
+
+       /* Calibration */
+       if (!client->raw) {
+               x = ((x * tsdev->cal.xscale) >> 8) + tsdev->cal.xtrans;
+               y = ((y * tsdev->cal.yscale) >> 8) + tsdev->cal.ytrans;
+               if (tsdev->cal.xyswap) {
+                       tmp = x; x = y; y = tmp;
+               }
+       }
+
+       event->millisecs = millisecs;
+       event->x = x;
+       event->y = y;
+       event->pressure = pressure;
+
+       spin_unlock(&client->buffer_lock);
+
+       kill_fasync(&client->fasync, SIGIO, POLL_IN);
+}
+
+static void tsdev_distribute_event(struct tsdev *tsdev)
+{
+       struct tsdev_client *client;
+       struct timeval time;
+       int millisecs;
+
+       do_gettimeofday(&time);
+       millisecs = time.tv_usec / 1000;
+
+       list_for_each_entry_rcu(client, &tsdev->client_list, node)
+               tsdev_pass_event(tsdev, client,
+                                tsdev->x, tsdev->y,
+                                tsdev->pressure, millisecs);
+}
+
 static void tsdev_event(struct input_handle *handle, unsigned int type,
                        unsigned int code, int value)
 {
        struct tsdev *tsdev = handle->private;
-       struct tsdev_client *client;
-       struct timeval time;
+       struct input_dev *dev = handle->dev;
+       int wake_up_readers = 0;
 
        switch (type) {
+
        case EV_ABS:
                switch (code) {
+
                case ABS_X:
                        tsdev->x = value;
                        break;
@@ -318,9 +462,9 @@ static void tsdev_event(struct input_handle *handle, unsigned int type,
                        break;
 
                case ABS_PRESSURE:
-                       if (value > handle->dev->absmax[ABS_PRESSURE])
-                               value = handle->dev->absmax[ABS_PRESSURE];
-                       value -= handle->dev->absmin[ABS_PRESSURE];
+                       if (value > dev->absmax[ABS_PRESSURE])
+                               value = dev->absmax[ABS_PRESSURE];
+                       value -= dev->absmin[ABS_PRESSURE];
                        if (value < 0)
                                value = 0;
                        tsdev->pressure = value;
@@ -330,6 +474,7 @@ static void tsdev_event(struct input_handle *handle, unsigned int type,
 
        case EV_REL:
                switch (code) {
+
                case REL_X:
                        tsdev->x += value;
                        if (tsdev->x < 0)
@@ -351,6 +496,7 @@ static void tsdev_event(struct input_handle *handle, unsigned int type,
        case EV_KEY:
                if (code == BTN_TOUCH || code == BTN_MOUSE) {
                        switch (value) {
+
                        case 0:
                                tsdev->pressure = 0;
                                break;
@@ -362,49 +508,71 @@ static void tsdev_event(struct input_handle *handle, unsigned int type,
                        }
                }
                break;
+
+       case EV_SYN:
+               if (code == SYN_REPORT) {
+                       tsdev_distribute_event(tsdev);
+                       wake_up_readers = 1;
+               }
+               break;
        }
 
-       if (type != EV_SYN || code != SYN_REPORT)
-               return;
+       if (wake_up_readers)
+               wake_up_interruptible(&tsdev->wait);
+}
+
+static int tsdev_install_chrdev(struct tsdev *tsdev)
+{
+       tsdev_table[tsdev->minor] = tsdev;
+       return 0;
+}
 
-       list_for_each_entry(client, &tsdev->client_list, node) {
-               int x, y, tmp;
+static void tsdev_remove_chrdev(struct tsdev *tsdev)
+{
+       mutex_lock(&tsdev_table_mutex);
+       tsdev_table[tsdev->minor] = NULL;
+       mutex_unlock(&tsdev_table_mutex);
+}
 
-               do_gettimeofday(&time);
-               client->event[client->head].millisecs = time.tv_usec / 1000;
-               client->event[client->head].pressure = tsdev->pressure;
+/*
+ * Mark device non-existant. This disables writes, ioctls and
+ * prevents new users from opening the device. Already posted
+ * blocking reads will stay, however new ones will fail.
+ */
+static void tsdev_mark_dead(struct tsdev *tsdev)
+{
+       mutex_lock(&tsdev->mutex);
+       tsdev->exist = 0;
+       mutex_unlock(&tsdev->mutex);
+}
 
-               x = tsdev->x;
-               y = tsdev->y;
+static void tsdev_cleanup(struct tsdev *tsdev)
+{
+       struct input_handle *handle = &tsdev->handle;
 
-               /* Calibration */
-               if (!client->raw) {
-                       x = ((x * tsdev->cal.xscale) >> 8) + tsdev->cal.xtrans;
-                       y = ((y * tsdev->cal.yscale) >> 8) + tsdev->cal.ytrans;
-                       if (tsdev->cal.xyswap) {
-                               tmp = x; x = y; y = tmp;
-                       }
-               }
+       tsdev_mark_dead(tsdev);
+       tsdev_hangup(tsdev);
+       tsdev_remove_chrdev(tsdev);
 
-               client->event[client->head].x = x;
-               client->event[client->head].y = y;
-               client->head = (client->head + 1) & (TSDEV_BUFFER_SIZE - 1);
-               kill_fasync(&client->fasync, SIGIO, POLL_IN);
-       }
-       wake_up_interruptible(&tsdev->wait);
+       /* tsdev is marked dead so noone else accesses tsdev->open */
+       if (tsdev->open)
+               input_close_device(handle);
 }
 
 static int tsdev_connect(struct input_handler *handler, struct input_dev *dev,
                         const struct input_device_id *id)
 {
        struct tsdev *tsdev;
-       int minor, delta;
+       int delta;
+       int minor;
        int error;
 
-       for (minor = 0; minor < TSDEV_MINORS / 2 && tsdev_table[minor]; minor++);
-       if (minor >= TSDEV_MINORS / 2) {
-               printk(KERN_ERR
-                      "tsdev: You have way too many touchscreens\n");
+       for (minor = 0; minor < TSDEV_MINORS / 2; minor++)
+               if (!tsdev_table[minor])
+                       break;
+
+       if (minor == TSDEV_MINORS) {
+               printk(KERN_ERR "tsdev: no more free tsdev devices\n");
                return -ENFILE;
        }
 
@@ -413,15 +581,18 @@ static int tsdev_connect(struct input_handler *handler, struct input_dev *dev,
                return -ENOMEM;
 
        INIT_LIST_HEAD(&tsdev->client_list);
+       spin_lock_init(&tsdev->client_lock);
+       mutex_init(&tsdev->mutex);
        init_waitqueue_head(&tsdev->wait);
 
+       snprintf(tsdev->name, sizeof(tsdev->name), "ts%d", minor);
        tsdev->exist = 1;
        tsdev->minor = minor;
+
        tsdev->handle.dev = dev;
        tsdev->handle.name = tsdev->name;
        tsdev->handle.handler = handler;
        tsdev->handle.private = tsdev;
-       snprintf(tsdev->name, sizeof(tsdev->name), "ts%d", minor);
 
        /* Precompute the rough calibration matrix */
        delta = dev->absmax [ABS_X] - dev->absmin [ABS_X] + 1;
@@ -436,28 +607,31 @@ static int tsdev_connect(struct input_handler *handler, struct input_dev *dev,
        tsdev->cal.yscale = (yres << 8) / delta;
        tsdev->cal.ytrans = - ((dev->absmin [ABS_Y] * tsdev->cal.yscale) >> 8);
 
-       snprintf(tsdev->dev.bus_id, sizeof(tsdev->dev.bus_id),
-                "ts%d", minor);
+       strlcpy(tsdev->dev.bus_id, tsdev->name, sizeof(tsdev->dev.bus_id));
+       tsdev->dev.devt = MKDEV(INPUT_MAJOR, TSDEV_MINOR_BASE + minor);
        tsdev->dev.class = &input_class;
        tsdev->dev.parent = &dev->dev;
-       tsdev->dev.devt = MKDEV(INPUT_MAJOR, TSDEV_MINOR_BASE + minor);
        tsdev->dev.release = tsdev_free;
        device_initialize(&tsdev->dev);
 
-       tsdev_table[minor] = tsdev;
-
-       error = device_add(&tsdev->dev);
+       error = input_register_handle(&tsdev->handle);
        if (error)
                goto err_free_tsdev;
 
-       error = input_register_handle(&tsdev->handle);
+       error = tsdev_install_chrdev(tsdev);
        if (error)
-               goto err_delete_tsdev;
+               goto err_unregister_handle;
+
+       error = device_add(&tsdev->dev);
+       if (error)
+               goto err_cleanup_tsdev;
 
        return 0;
 
- err_delete_tsdev:
-       device_del(&tsdev->dev);
+ err_cleanup_tsdev:
+       tsdev_cleanup(tsdev);
+ err_unregister_handle:
+       input_unregister_handle(&tsdev->handle);
  err_free_tsdev:
        put_device(&tsdev->dev);
        return error;
@@ -466,20 +640,10 @@ static int tsdev_connect(struct input_handler *handler, struct input_dev *dev,
 static void tsdev_disconnect(struct input_handle *handle)
 {
        struct tsdev *tsdev = handle->private;
-       struct tsdev_client *client;
 
-       input_unregister_handle(handle);
        device_del(&tsdev->dev);
-
-       tsdev->exist = 0;
-
-       if (tsdev->open) {
-               input_close_device(handle);
-               list_for_each_entry(client, &tsdev->client_list, node)
-                       kill_fasync(&client->fasync, SIGIO, POLL_HUP);
-               wake_up_interruptible(&tsdev->wait);
-       }
-
+       tsdev_cleanup(tsdev);
+       input_unregister_handle(handle);
        put_device(&tsdev->dev);
 }
 
@@ -510,13 +674,13 @@ static const struct input_device_id tsdev_ids[] = {
 MODULE_DEVICE_TABLE(input, tsdev_ids);
 
 static struct input_handler tsdev_handler = {
-       .event =        tsdev_event,
-       .connect =      tsdev_connect,
-       .disconnect =   tsdev_disconnect,
-       .fops =         &tsdev_fops,
-       .minor =        TSDEV_MINOR_BASE,
-       .name =         "tsdev",
-       .id_table =     tsdev_ids,
+       .event          = tsdev_event,
+       .connect        = tsdev_connect,
+       .disconnect     = tsdev_disconnect,
+       .fops           = &tsdev_fops,
+       .minor          = TSDEV_MINOR_BASE,
+       .name           = "tsdev",
+       .id_table       = tsdev_ids,
 };
 
 static int __init tsdev_init(void)