net: ipv4: add second dif to udp socket lookups
author David Ahern <dsahern@gmail.com>
Mon, 7 Aug 2017 15:44:16 +0000 (08:44 -0700)
committer David S. Miller <davem@davemloft.net>
Mon, 7 Aug 2017 18:39:21 +0000 (11:39 -0700)
Add a second device index, sdif, to udp socket lookups. sdif is the
index for ingress devices enslaved to an l3mdev. It allows the lookups
to consider the enslaved device as well as the L3 domain when searching
for a socket.

Early demux lookups are handled in the next patch as part of INET_MATCH
changes.

Signed-off-by: David Ahern <dsahern@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/ip.h
include/net/udp.h
net/ipv4/udp.c
net/ipv4/udp_diag.c

index 9e59dcf1787a88a06153dab81584cdafb9c8b8ba..39db596eb89fc346c549945482582f0abc89b6f0 100644 (file)
@@ -78,6 +78,16 @@ struct ipcm_cookie {
 #define IPCB(skb) ((struct inet_skb_parm*)((skb)->cb))
 #define PKTINFO_SKB_CB(skb) ((struct in_pktinfo *)((skb)->cb))
 
+/* Return the enslaved ingress device index (sdif) for @skb.
+ *
+ * With CONFIG_NET_L3_MASTER_DEV enabled, a non-NULL skb whose IPCB flags
+ * satisfy ipv4_l3mdev_skb() yields IPCB(skb)->iif. In every other case
+ * (option disabled, NULL skb, or the flag test fails) the result is 0.
+ */
+static inline int inet_sdif(struct sk_buff *skb)
+{
+#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
+       if (skb && ipv4_l3mdev_skb(IPCB(skb)->flags))
+               return IPCB(skb)->iif;
+#endif
+       return 0;
+}
+
 struct ip_ra_chain {
        struct ip_ra_chain __rcu *next;
        struct sock             *sk;
index cc8036987dcb885012c6c5eda0fb2bed2e588841..826c713d5a4858dc90409fc1da4a3b9de698a4ce 100644 (file)
@@ -287,7 +287,7 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
 struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
                             __be32 daddr, __be16 dport, int dif);
 struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
-                              __be32 daddr, __be16 dport, int dif,
+                              __be32 daddr, __be16 dport, int dif, int sdif,
                               struct udp_table *tbl, struct sk_buff *skb);
 struct sock *udp4_lib_lookup_skb(struct sk_buff *skb,
                                 __be16 sport, __be16 dport);
index 38bca2c4897d6cd5ed5ed92475d4715ad627a168..fe14429e4a6c834dd66d499ba3f2cab298bb5e6b 100644 (file)
@@ -380,8 +380,8 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum)
 
 static int compute_score(struct sock *sk, struct net *net,
                         __be32 saddr, __be16 sport,
-                        __be32 daddr, unsigned short hnum, int dif,
-                        bool exact_dif)
+                        __be32 daddr, unsigned short hnum,
+                        int dif, int sdif, bool exact_dif)
 {
        int score;
        struct inet_sock *inet;
@@ -413,10 +413,15 @@ static int compute_score(struct sock *sk, struct net *net,
        }
 
        if (sk->sk_bound_dev_if || exact_dif) {
-               if (sk->sk_bound_dev_if != dif)
+               bool dev_match = (sk->sk_bound_dev_if == dif ||
+                                 sk->sk_bound_dev_if == sdif);
+
+               if (exact_dif && !dev_match)
                        return -1;
-               score += 4;
+               if (sk->sk_bound_dev_if && dev_match)
+                       score += 4;
        }
+
        if (sk->sk_incoming_cpu == raw_smp_processor_id())
                score++;
        return score;
@@ -436,10 +441,11 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr,
 
 /* called with rcu_read_lock() */
 static struct sock *udp4_lib_lookup2(struct net *net,
-               __be32 saddr, __be16 sport,
-               __be32 daddr, unsigned int hnum, int dif, bool exact_dif,
-               struct udp_hslot *hslot2,
-               struct sk_buff *skb)
+                                    __be32 saddr, __be16 sport,
+                                    __be32 daddr, unsigned int hnum,
+                                    int dif, int sdif, bool exact_dif,
+                                    struct udp_hslot *hslot2,
+                                    struct sk_buff *skb)
 {
        struct sock *sk, *result;
        int score, badness, matches = 0, reuseport = 0;
@@ -449,7 +455,7 @@ static struct sock *udp4_lib_lookup2(struct net *net,
        badness = 0;
        udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
                score = compute_score(sk, net, saddr, sport,
-                                     daddr, hnum, dif, exact_dif);
+                                     daddr, hnum, dif, sdif, exact_dif);
                if (score > badness) {
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
@@ -477,8 +483,8 @@ static struct sock *udp4_lib_lookup2(struct net *net,
  * harder than this. -DaveM
  */
 struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
-               __be16 sport, __be32 daddr, __be16 dport,
-               int dif, struct udp_table *udptable, struct sk_buff *skb)
+               __be16 sport, __be32 daddr, __be16 dport, int dif,
+               int sdif, struct udp_table *udptable, struct sk_buff *skb)
 {
        struct sock *sk, *result;
        unsigned short hnum = ntohs(dport);
@@ -496,7 +502,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
                        goto begin;
 
                result = udp4_lib_lookup2(net, saddr, sport,
-                                         daddr, hnum, dif,
+                                         daddr, hnum, dif, sdif,
                                          exact_dif, hslot2, skb);
                if (!result) {
                        unsigned int old_slot2 = slot2;
@@ -511,7 +517,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
                                goto begin;
 
                        result = udp4_lib_lookup2(net, saddr, sport,
-                                                 daddr, hnum, dif,
+                                                 daddr, hnum, dif, sdif,
                                                  exact_dif, hslot2, skb);
                }
                return result;
@@ -521,7 +527,7 @@ begin:
        badness = 0;
        sk_for_each_rcu(sk, &hslot->head) {
                score = compute_score(sk, net, saddr, sport,
-                                     daddr, hnum, dif, exact_dif);
+                                     daddr, hnum, dif, sdif, exact_dif);
                if (score > badness) {
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
@@ -554,7 +560,7 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
 
        return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport,
                                 iph->daddr, dport, inet_iif(skb),
-                                udptable, skb);
+                                inet_sdif(skb), udptable, skb);
 }
 
 struct sock *udp4_lib_lookup_skb(struct sk_buff *skb,
@@ -576,7 +582,7 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
        struct sock *sk;
 
        sk = __udp4_lib_lookup(net, saddr, sport, daddr, dport,
-                              dif, &udp_table, NULL);
+                              dif, 0, &udp_table, NULL);
        if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
                sk = NULL;
        return sk;
@@ -587,7 +593,7 @@ EXPORT_SYMBOL_GPL(udp4_lib_lookup);
 static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
                                       __be16 loc_port, __be32 loc_addr,
                                       __be16 rmt_port, __be32 rmt_addr,
-                                      int dif, unsigned short hnum)
+                                      int dif, int sdif, unsigned short hnum)
 {
        struct inet_sock *inet = inet_sk(sk);
 
@@ -597,7 +603,8 @@ static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
            (inet->inet_dport != rmt_port && inet->inet_dport) ||
            (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) ||
            ipv6_only_sock(sk) ||
-           (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif))
+           (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif &&
+            sk->sk_bound_dev_if != sdif))
                return false;
        if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif))
                return false;
@@ -628,8 +635,8 @@ void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
        struct net *net = dev_net(skb->dev);
 
        sk = __udp4_lib_lookup(net, iph->daddr, uh->dest,
-                       iph->saddr, uh->source, skb->dev->ifindex, udptable,
-                       NULL);
+                              iph->saddr, uh->source, skb->dev->ifindex, 0,
+                              udptable, NULL);
        if (!sk) {
                __ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
                return; /* No socket for error */
@@ -1953,6 +1960,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
        unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10);
        unsigned int offset = offsetof(typeof(*sk), sk_node);
        int dif = skb->dev->ifindex;
+       int sdif = inet_sdif(skb);
        struct hlist_node *node;
        struct sk_buff *nskb;
 
@@ -1967,7 +1975,7 @@ start_lookup:
 
        sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) {
                if (!__udp_is_mcast_sock(net, sk, uh->dest, daddr,
-                                        uh->source, saddr, dif, hnum))
+                                        uh->source, saddr, dif, sdif, hnum))
                        continue;
 
                if (!first) {
@@ -2157,7 +2165,7 @@ drop:
 static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
                                                  __be16 loc_port, __be32 loc_addr,
                                                  __be16 rmt_port, __be32 rmt_addr,
-                                                 int dif)
+                                                 int dif, int sdif)
 {
        struct sock *sk, *result;
        unsigned short hnum = ntohs(loc_port);
@@ -2171,7 +2179,7 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
        result = NULL;
        sk_for_each_rcu(sk, &hslot->head) {
                if (__udp_is_mcast_sock(net, sk, loc_port, loc_addr,
-                                       rmt_port, rmt_addr, dif, hnum)) {
+                                       rmt_port, rmt_addr, dif, sdif, hnum)) {
                        if (result)
                                return NULL;
                        result = sk;
@@ -2216,6 +2224,7 @@ void udp_v4_early_demux(struct sk_buff *skb)
        struct sock *sk = NULL;
        struct dst_entry *dst;
        int dif = skb->dev->ifindex;
+       int sdif = inet_sdif(skb);
        int ours;
 
        /* validate the packet */
@@ -2241,7 +2250,8 @@ void udp_v4_early_demux(struct sk_buff *skb)
                }
 
                sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr,
-                                                  uh->source, iph->saddr, dif);
+                                                  uh->source, iph->saddr,
+                                                  dif, sdif);
        } else if (skb->pkt_type == PACKET_HOST) {
                sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr,
                                             uh->source, iph->saddr, dif);
index 4515836d2a3ac309c9305a4ee1ce95fa0b6e26d4..1f07fe1095352979f806aa25ee38dff210799c63 100644 (file)
@@ -45,7 +45,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
                sk = __udp4_lib_lookup(net,
                                req->id.idiag_src[0], req->id.idiag_sport,
                                req->id.idiag_dst[0], req->id.idiag_dport,
-                               req->id.idiag_if, tbl, NULL);
+                               req->id.idiag_if, 0, tbl, NULL);
 #if IS_ENABLED(CONFIG_IPV6)
        else if (req->sdiag_family == AF_INET6)
                sk = __udp6_lib_lookup(net,
@@ -182,7 +182,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
                sk = __udp4_lib_lookup(net,
                                req->id.idiag_dst[0], req->id.idiag_dport,
                                req->id.idiag_src[0], req->id.idiag_sport,
-                               req->id.idiag_if, tbl, NULL);
+                               req->id.idiag_if, 0, tbl, NULL);
 #if IS_ENABLED(CONFIG_IPV6)
        else if (req->sdiag_family == AF_INET6) {
                if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
@@ -190,7 +190,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
                        sk = __udp4_lib_lookup(net,
                                        req->id.idiag_dst[3], req->id.idiag_dport,
                                        req->id.idiag_src[3], req->id.idiag_sport,
-                                       req->id.idiag_if, tbl, NULL);
+                                       req->id.idiag_if, 0, tbl, NULL);
 
                else
                        sk = __udp6_lib_lookup(net,