net/packet: check length in getsockopt() called with PACKET_HDRLEN
[GitHub/mt8127/android_kernel_alcatel_ttab.git] / net / packet / af_packet.c
index 20a1bd0e6549dcd5c55592c9138c4e99556f3816..b915d0112874b8d1f8c0ac8a3f3a38cfc58ab6e5 100644 (file)
@@ -237,6 +237,30 @@ struct packet_skb_cb {
 static void __fanout_unlink(struct sock *sk, struct packet_sock *po);
 static void __fanout_link(struct sock *sk, struct packet_sock *po);
 
+static struct net_device *packet_cached_dev_get(struct packet_sock *po)
+{
+       struct net_device *dev;
+
+       rcu_read_lock();
+       dev = rcu_dereference(po->cached_dev);
+       if (likely(dev))
+               dev_hold(dev);
+       rcu_read_unlock();
+
+       return dev;
+}
+
+static void packet_cached_dev_assign(struct packet_sock *po,
+                                    struct net_device *dev)
+{
+       rcu_assign_pointer(po->cached_dev, dev);
+}
+
+static void packet_cached_dev_reset(struct packet_sock *po)
+{
+       RCU_INIT_POINTER(po->cached_dev, NULL);
+}
+
 /* register_prot_hook must be invoked with the po->bind_lock held,
  * or from a context in which asynchronous accesses to the packet
  * socket is not possible (packet_create()).
@@ -244,11 +268,13 @@ static void __fanout_link(struct sock *sk, struct packet_sock *po);
 static void register_prot_hook(struct sock *sk)
 {
        struct packet_sock *po = pkt_sk(sk);
+
        if (!po->running) {
                if (po->fanout)
                        __fanout_link(sk, po);
                else
                        dev_add_pack(&po->prot_hook);
+
                sock_hold(sk);
                po->running = 1;
        }
@@ -266,10 +292,12 @@ static void __unregister_prot_hook(struct sock *sk, bool sync)
        struct packet_sock *po = pkt_sk(sk);
 
        po->running = 0;
+
        if (po->fanout)
                __fanout_unlink(sk, po);
        else
                __dev_remove_pack(&po->prot_hook);
+
        __sock_put(sk);
 
        if (sync) {
@@ -432,9 +460,9 @@ static void prb_shutdown_retire_blk_timer(struct packet_sock *po,
 
        pkc = tx_ring ? &po->tx_ring.prb_bdqc : &po->rx_ring.prb_bdqc;
 
-       spin_lock(&rb_queue->lock);
+       spin_lock_bh(&rb_queue->lock);
        pkc->delete_blk_timer = 1;
-       spin_unlock(&rb_queue->lock);
+       spin_unlock_bh(&rb_queue->lock);
 
        prb_del_retire_blk_timer(pkc);
 }
@@ -537,6 +565,7 @@ static void init_prb_bdqc(struct packet_sock *po,
        p1->tov_in_jiffies = msecs_to_jiffies(p1->retire_blk_tov);
        p1->blk_sizeof_priv = req_u->req3.tp_sizeof_priv;
 
+       p1->max_frame_len = p1->kblk_size - BLK_PLUS_PRIV(p1->blk_sizeof_priv);
        prb_init_ft_ops(p1, req_u);
        prb_setup_retire_blk_timer(po, tx_ring);
        prb_open_block(p1, pbd);
@@ -1121,16 +1150,6 @@ static void packet_sock_destruct(struct sock *sk)
        sk_refcnt_debug_dec(sk);
 }
 
-static int fanout_rr_next(struct packet_fanout *f, unsigned int num)
-{
-       int x = atomic_read(&f->rr_cur) + 1;
-
-       if (x >= num)
-               x = 0;
-
-       return x;
-}
-
 static unsigned int fanout_demux_hash(struct packet_fanout *f,
                                      struct sk_buff *skb,
                                      unsigned int num)
@@ -1142,13 +1161,9 @@ static unsigned int fanout_demux_lb(struct packet_fanout *f,
                                    struct sk_buff *skb,
                                    unsigned int num)
 {
-       int cur, old;
+       unsigned int val = atomic_inc_return(&f->rr_cur);
 
-       cur = atomic_read(&f->rr_cur);
-       while ((old = atomic_cmpxchg(&f->rr_cur, cur,
-                                    fanout_rr_next(f, num))) != cur)
-               cur = old;
-       return cur;
+       return val % num;
 }
 
 static unsigned int fanout_demux_cpu(struct packet_fanout *f,
@@ -1188,7 +1203,7 @@ static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
                             struct packet_type *pt, struct net_device *orig_dev)
 {
        struct packet_fanout *f = pt->af_packet_priv;
-       unsigned int num = f->num_members;
+       unsigned int num = ACCESS_ONCE(f->num_members);
        struct packet_sock *po;
        unsigned int idx;
 
@@ -1242,6 +1257,8 @@ static void __fanout_link(struct sock *sk, struct packet_sock *po)
        f->arr[f->num_members] = sk;
        smp_wmb();
        f->num_members++;
+       if (f->num_members == 1)
+               dev_add_pack(&f->prot_hook);
        spin_unlock(&f->lock);
 }
 
@@ -1258,6 +1275,8 @@ static void __fanout_unlink(struct sock *sk, struct packet_sock *po)
        BUG_ON(i >= f->num_members);
        f->arr[i] = f->arr[f->num_members - 1];
        f->num_members--;
+       if (f->num_members == 0)
+               __dev_remove_pack(&f->prot_hook);
        spin_unlock(&f->lock);
 }
 
@@ -1289,13 +1308,16 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
                return -EINVAL;
        }
 
+       mutex_lock(&fanout_mutex);
+
+       err = -EINVAL;
        if (!po->running)
-               return -EINVAL;
+               goto out;
 
+       err = -EALREADY;
        if (po->fanout)
-               return -EALREADY;
+               goto out;
 
-       mutex_lock(&fanout_mutex);
        match = NULL;
        list_for_each_entry(f, &fanout_list, list) {
                if (f->id == id &&
@@ -1325,7 +1347,6 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
                match->prot_hook.func = packet_rcv_fanout;
                match->prot_hook.af_packet_priv = match;
                match->prot_hook.id_match = match_fanout_group;
-               dev_add_pack(&match->prot_hook);
                list_add(&match->list, &fanout_list);
        }
        err = -EINVAL;
@@ -1346,24 +1367,29 @@ out:
        return err;
 }
 
-static void fanout_release(struct sock *sk)
+/* Drop this socket's reference on pkt_sk(sk)->fanout. If sk_ref drops to
+ * zero, remove the fanout from fanout_list and return it; otherwise (or
+ * if the socket has no fanout) return NULL. The caller is responsible for
+ * freeing the returned packet_fanout with kfree() after synchronize_net().
+ */
+static struct packet_fanout *fanout_release(struct sock *sk)
 {
        struct packet_sock *po = pkt_sk(sk);
        struct packet_fanout *f;
 
-       f = po->fanout;
-       if (!f)
-               return;
-
        mutex_lock(&fanout_mutex);
-       po->fanout = NULL;
+       f = po->fanout;
+       if (f) {
+               po->fanout = NULL;
 
-       if (atomic_dec_and_test(&f->sk_ref)) {
-               list_del(&f->list);
-               dev_remove_pack(&f->prot_hook);
-               kfree(f);
+               if (atomic_dec_and_test(&f->sk_ref))
+                       list_del(&f->list);
+               else
+                       f = NULL;
        }
        mutex_unlock(&fanout_mutex);
+
+       return f;
 }
 
 static const struct proto_ops packet_ops;
@@ -1775,6 +1801,18 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                        if ((int)snaplen < 0)
                                snaplen = 0;
                }
+       } else if (unlikely(macoff + snaplen >
+                           GET_PBDQC_FROM_RB(&po->rx_ring)->max_frame_len)) {
+               u32 nval;
+
+               nval = GET_PBDQC_FROM_RB(&po->rx_ring)->max_frame_len - macoff;
+               pr_err_once("tpacket_rcv: packet too big, clamped from %u to %u. macoff=%u\n",
+                           snaplen, nval, macoff);
+               snaplen = nval;
+               if (unlikely((int)snaplen < 0)) {
+                       snaplen = 0;
+                       macoff = GET_PBDQC_FROM_RB(&po->rx_ring)->max_frame_len;
+               }
        }
        spin_lock(&sk->sk_receive_queue.lock);
        h.raw = packet_current_rx_frame(po, skb,
@@ -2046,7 +2084,6 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
        struct sk_buff *skb;
        struct net_device *dev;
        __be16 proto;
-       bool need_rls_dev = false;
        int err, reserve = 0;
        void *ph;
        struct sockaddr_ll *saddr = (struct sockaddr_ll *)msg->msg_name;
@@ -2058,8 +2095,8 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
 
        mutex_lock(&po->pg_vec_lock);
 
-       if (saddr == NULL) {
-               dev = po->prot_hook.dev;
+       if (likely(saddr == NULL)) {
+               dev     = packet_cached_dev_get(po);
                proto   = po->num;
                addr    = NULL;
        } else {
@@ -2073,19 +2110,17 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
                proto   = saddr->sll_protocol;
                addr    = saddr->sll_addr;
                dev = dev_get_by_index(sock_net(&po->sk), saddr->sll_ifindex);
-               need_rls_dev = true;
        }
 
        err = -ENXIO;
        if (unlikely(dev == NULL))
                goto out;
-
-       reserve = dev->hard_header_len;
-
        err = -ENETDOWN;
        if (unlikely(!(dev->flags & IFF_UP)))
                goto out_put;
 
+       reserve = dev->hard_header_len;
+
        size_max = po->tx_ring.frame_size
                - (po->tp_hdrlen - sizeof(struct sockaddr_ll));
 
@@ -2162,8 +2197,7 @@ out_status:
        __packet_set_status(po, ph, status);
        kfree_skb(skb);
 out_put:
-       if (need_rls_dev)
-               dev_put(dev);
+       dev_put(dev);
 out:
        mutex_unlock(&po->pg_vec_lock);
        return err;
@@ -2201,7 +2235,6 @@ static int packet_snd(struct socket *sock,
        struct sk_buff *skb;
        struct net_device *dev;
        __be16 proto;
-       bool need_rls_dev = false;
        unsigned char *addr;
        int err, reserve = 0;
        struct virtio_net_hdr vnet_hdr = { 0 };
@@ -2209,15 +2242,15 @@ static int packet_snd(struct socket *sock,
        int vnet_hdr_len;
        struct packet_sock *po = pkt_sk(sk);
        unsigned short gso_type = 0;
-       int hlen, tlen;
+       int hlen, tlen, linear;
        int extra_len = 0;
 
        /*
         *      Get and verify the address.
         */
 
-       if (saddr == NULL) {
-               dev = po->prot_hook.dev;
+       if (likely(saddr == NULL)) {
+               dev     = packet_cached_dev_get(po);
                proto   = po->num;
                addr    = NULL;
        } else {
@@ -2229,19 +2262,17 @@ static int packet_snd(struct socket *sock,
                proto   = saddr->sll_protocol;
                addr    = saddr->sll_addr;
                dev = dev_get_by_index(sock_net(sk), saddr->sll_ifindex);
-               need_rls_dev = true;
        }
 
        err = -ENXIO;
-       if (dev == NULL)
+       if (unlikely(dev == NULL))
                goto out_unlock;
-       if (sock->type == SOCK_RAW)
-               reserve = dev->hard_header_len;
-
        err = -ENETDOWN;
-       if (!(dev->flags & IFF_UP))
+       if (unlikely(!(dev->flags & IFF_UP)))
                goto out_unlock;
 
+       if (sock->type == SOCK_RAW)
+               reserve = dev->hard_header_len;
        if (po->has_vnet_hdr) {
                vnet_hdr_len = sizeof(vnet_hdr);
 
@@ -2305,7 +2336,9 @@ static int packet_snd(struct socket *sock,
        err = -ENOBUFS;
        hlen = LL_RESERVED_SPACE(dev);
        tlen = dev->needed_tailroom;
-       skb = packet_alloc_skb(sk, hlen + tlen, hlen, len, vnet_hdr.hdr_len,
+       linear = vnet_hdr.hdr_len;
+       linear = max(linear, min_t(int, len, dev->hard_header_len));
+       skb = packet_alloc_skb(sk, hlen + tlen, hlen, len, linear,
                               msg->msg_flags & MSG_DONTWAIT, &err);
        if (skb == NULL)
                goto out_unlock;
@@ -2375,15 +2408,14 @@ static int packet_snd(struct socket *sock,
        if (err > 0 && (err = net_xmit_errno(err)) != 0)
                goto out_unlock;
 
-       if (need_rls_dev)
-               dev_put(dev);
+       dev_put(dev);
 
        return len;
 
 out_free:
        kfree_skb(skb);
 out_unlock:
-       if (dev && need_rls_dev)
+       if (dev)
                dev_put(dev);
 out:
        return err;
@@ -2409,6 +2441,7 @@ static int packet_release(struct socket *sock)
 {
        struct sock *sk = sock->sk;
        struct packet_sock *po;
+       struct packet_fanout *f;
        struct net *net;
        union tpacket_req_u req_u;
 
@@ -2428,6 +2461,8 @@ static int packet_release(struct socket *sock)
 
        spin_lock(&po->bind_lock);
        unregister_prot_hook(sk, false);
+       packet_cached_dev_reset(po);
+
        if (po->prot_hook.dev) {
                dev_put(po->prot_hook.dev);
                po->prot_hook.dev = NULL;
@@ -2446,9 +2481,13 @@ static int packet_release(struct socket *sock)
                packet_set_ring(sk, &req_u, 1, 1);
        }
 
-       fanout_release(sk);
+       f = fanout_release(sk);
 
        synchronize_net();
+
+       if (f) {
+               kfree(f);
+       }
        /*
         *      Now the socket is dead. No more input will appear.
         */
@@ -2483,14 +2522,17 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc
 
        spin_lock(&po->bind_lock);
        unregister_prot_hook(sk, true);
+
        po->num = protocol;
        po->prot_hook.type = protocol;
        if (po->prot_hook.dev)
                dev_put(po->prot_hook.dev);
-       po->prot_hook.dev = dev;
 
+       po->prot_hook.dev = dev;
        po->ifindex = dev ? dev->ifindex : 0;
 
+       packet_cached_dev_assign(po, dev);
+
        if (protocol == 0)
                goto out_unlock;
 
@@ -2516,7 +2558,7 @@ static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr,
                            int addr_len)
 {
        struct sock *sk = sock->sk;
-       char name[15];
+       char name[sizeof(uaddr->sa_data) + 1];
        struct net_device *dev;
        int err = -ENODEV;
 
@@ -2526,7 +2568,11 @@ static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr,
 
        if (addr_len != sizeof(struct sockaddr))
                return -EINVAL;
-       strlcpy(name, uaddr->sa_data, sizeof(name));
+       /* uaddr->sa_data comes from the userspace, it's not guaranteed to be
+        * zero-terminated.
+        */
+       memcpy(name, uaddr->sa_data, sizeof(uaddr->sa_data));
+       name[sizeof(uaddr->sa_data)] = 0;
 
        dev = dev_get_by_name(sock_net(sk), name);
        if (dev)
@@ -2604,6 +2650,8 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
        sk->sk_family = PF_PACKET;
        po->num = proto;
 
+       packet_cached_dev_reset(po);
+
        sk->sk_destruct = packet_sock_destruct;
        sk_refcnt_debug_inc(sk);
 
@@ -2694,7 +2742,6 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
        struct sock *sk = sock->sk;
        struct sk_buff *skb;
        int copied, err;
-       struct sockaddr_ll *sll;
        int vnet_hdr_len = 0;
 
        err = -EINVAL;
@@ -2777,22 +2824,10 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
                        goto out_free;
        }
 
-       /*
-        *      If the address length field is there to be filled in, we fill
-        *      it in now.
-        */
-
-       sll = &PACKET_SKB_CB(skb)->sa.ll;
-       if (sock->type == SOCK_PACKET)
-               msg->msg_namelen = sizeof(struct sockaddr_pkt);
-       else
-               msg->msg_namelen = sll->sll_halen + offsetof(struct sockaddr_ll, sll_addr);
-
-       /*
-        *      You lose any data beyond the buffer you gave. If it worries a
-        *      user program they can ask the device for its MTU anyway.
+       /* You lose any data beyond the buffer you gave. If it worries
+        * a user program they can ask the device for its MTU
+        * anyway.
         */
-
        copied = skb->len;
        if (copied > len) {
                copied = len;
@@ -2805,9 +2840,20 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
 
        sock_recv_ts_and_drops(msg, sk, skb);
 
-       if (msg->msg_name)
+       if (msg->msg_name) {
+               /* If the address length field is there to be filled
+                * in, we fill it in now.
+                */
+               if (sock->type == SOCK_PACKET) {
+                       msg->msg_namelen = sizeof(struct sockaddr_pkt);
+               } else {
+                       struct sockaddr_ll *sll = &PACKET_SKB_CB(skb)->sa.ll;
+                       msg->msg_namelen = sll->sll_halen +
+                               offsetof(struct sockaddr_ll, sll_addr);
+               }
                memcpy(msg->msg_name, &PACKET_SKB_CB(skb)->sa,
                       msg->msg_namelen);
+       }
 
        if (pkt_sk(sk)->auxdata) {
                struct tpacket_auxdata aux;
@@ -2973,6 +3019,7 @@ static int packet_mc_add(struct sock *sk, struct packet_mreq_max *mreq)
        i->ifindex = mreq->mr_ifindex;
        i->alen = mreq->mr_alen;
        memcpy(i->addr, mreq->mr_address, i->alen);
+       memset(i->addr + i->alen, 0, sizeof(i->addr) - i->alen);
        i->count = 1;
        i->next = po->mclist;
        po->mclist = i;
@@ -3110,19 +3157,25 @@ packet_setsockopt(struct socket *sock, int level, int optname, char __user *optv
 
                if (optlen != sizeof(val))
                        return -EINVAL;
-               if (po->rx_ring.pg_vec || po->tx_ring.pg_vec)
-                       return -EBUSY;
                if (copy_from_user(&val, optval, sizeof(val)))
                        return -EFAULT;
                switch (val) {
                case TPACKET_V1:
                case TPACKET_V2:
                case TPACKET_V3:
-                       po->tp_version = val;
-                       return 0;
+                       break;
                default:
                        return -EINVAL;
                }
+               lock_sock(sk);
+               if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) {
+                       ret = -EBUSY;
+               } else {
+                       po->tp_version = val;
+                       ret = 0;
+               }
+               release_sock(sk);
+               return ret;
        }
        case PACKET_RESERVE:
        {
@@ -3134,6 +3187,8 @@ packet_setsockopt(struct socket *sock, int level, int optname, char __user *optv
                        return -EBUSY;
                if (copy_from_user(&val, optval, sizeof(val)))
                        return -EFAULT;
+               if (val > INT_MAX)
+                       return -EINVAL;
                po->tp_reserve = val;
                return 0;
        }
@@ -3259,9 +3314,11 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
 
                if (po->tp_version == TPACKET_V3) {
                        lv = sizeof(struct tpacket_stats_v3);
+                       st.stats3.tp_packets += st.stats3.tp_drops;
                        data = &st.stats3;
                } else {
                        lv = sizeof(struct tpacket_stats);
+                       st.stats1.tp_packets += st.stats1.tp_drops;
                        data = &st.stats1;
                }
 
@@ -3281,6 +3338,8 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
        case PACKET_HDRLEN:
                if (len > sizeof(int))
                        len = sizeof(int);
+               if (len < sizeof(int))
+                       return -EINVAL;
                if (copy_from_user(&val, optval, len))
                        return -EFAULT;
                switch (val) {
@@ -3356,6 +3415,7 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void
                                                sk->sk_error_report(sk);
                                }
                                if (msg == NETDEV_UNREGISTER) {
+                                       packet_cached_dev_reset(po);
                                        po->ifindex = -1;
                                        if (po->prot_hook.dev)
                                                dev_put(po->prot_hook.dev);
@@ -3574,6 +3634,7 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
        /* Added to avoid minimal code churn */
        struct tpacket_req *req = &req_u->req;
 
+       lock_sock(sk);
        /* Opening a Tx-ring is NOT supported in TPACKET_V3 */
        if (!closing && tx_ring && (po->tp_version > TPACKET_V2)) {
                WARN(1, "Tx-ring is not supported.\n");
@@ -3614,6 +3675,10 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
                        goto out;
                if (unlikely(req->tp_block_size & (PAGE_SIZE - 1)))
                        goto out;
+               if (po->tp_version >= TPACKET_V3 &&
+                   req->tp_block_size <=
+                         BLK_PLUS_PRIV((u64)req_u->req3.tp_sizeof_priv))
+                       goto out;
                if (unlikely(req->tp_frame_size < po->tp_hdrlen +
                                        po->tp_reserve))
                        goto out;
@@ -3623,6 +3688,8 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
                rb->frames_per_block = req->tp_block_size/req->tp_frame_size;
                if (unlikely(rb->frames_per_block <= 0))
                        goto out;
+               if (unlikely(req->tp_block_size > UINT_MAX / req->tp_block_nr))
+                       goto out;
                if (unlikely((rb->frames_per_block * req->tp_block_nr) !=
                                        req->tp_frame_nr))
                        goto out;
@@ -3639,7 +3706,7 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
                 */
                        if (!tx_ring)
                                init_prb_bdqc(po, rb, pg_vec, req_u, tx_ring);
-                               break;
+                       break;
                default:
                        break;
                }
@@ -3651,7 +3718,6 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
                        goto out;
        }
 
-       lock_sock(sk);
 
        /* Detach socket from network */
        spin_lock(&po->bind_lock);
@@ -3700,11 +3766,11 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
                if (!tx_ring)
                        prb_shutdown_retire_blk_timer(po, tx_ring, rb_queue);
        }
-       release_sock(sk);
 
        if (pg_vec)
                free_pg_vec(pg_vec, order, req->tp_block_nr);
 out:
+       release_sock(sk);
        return err;
 }