#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <linux/if_arp.h>
+#include <linux/rhashtable.h>
#include <asm/cacheflush.h>
+#include <linux/hash.h>
#include <net/net_namespace.h>
#include <net/sock.h>
#define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock));
+/* Protects netlink socket hash table mutations */
+DEFINE_MUTEX(nl_sk_hash_lock);
+
+static int lockdep_nl_sk_hash_is_held(void)
+{
+#ifdef CONFIG_LOCKDEP
+ return (debug_locks) ? lockdep_is_held(&nl_sk_hash_lock) : 1;
+#else
+ return 1;
+#endif
+}
+
static ATOMIC_NOTIFIER_HEAD(netlink_chain);
static DEFINE_SPINLOCK(netlink_tap_lock);
return group ? 1 << (group - 1) : 0;
}
-static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u32 portid)
-{
- return &hash->table[jhash_1word(portid, hash->rnd) & hash->mask];
-}
-
int netlink_add_tap(struct netlink_tap *nt)
{
if (unlikely(nt->dev->type != ARPHRD_NETLINK))
wake_up(&nl_table_wait);
}
-static bool netlink_compare(struct net *net, struct sock *sk)
-{
- return net_eq(sock_net(sk), net);
-}
-
-static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
+struct netlink_compare_arg
{
- struct netlink_table *table = &nl_table[protocol];
- struct nl_portid_hash *hash = &table->hash;
- struct hlist_head *head;
- struct sock *sk;
-
- read_lock(&nl_table_lock);
- head = nl_portid_hashfn(hash, portid);
- sk_for_each(sk, head) {
- if (table->compare(net, sk) &&
- (nlk_sk(sk)->portid == portid)) {
- sock_hold(sk);
- goto found;
- }
- }
- sk = NULL;
-found:
- read_unlock(&nl_table_lock);
- return sk;
-}
+ struct net *net;
+ u32 portid;
+};
-static struct hlist_head *nl_portid_hash_zalloc(size_t size)
+static bool netlink_compare(void *ptr, void *arg)
{
- if (size <= PAGE_SIZE)
- return kzalloc(size, GFP_ATOMIC);
- else
- return (struct hlist_head *)
- __get_free_pages(GFP_ATOMIC | __GFP_ZERO,
- get_order(size));
-}
+ struct netlink_compare_arg *x = arg;
+ struct sock *sk = ptr;
-static void nl_portid_hash_free(struct hlist_head *table, size_t size)
-{
- if (size <= PAGE_SIZE)
- kfree(table);
- else
- free_pages((unsigned long)table, get_order(size));
+ return nlk_sk(sk)->portid == x->portid &&
+ net_eq(sock_net(sk), x->net);
}
-static int nl_portid_hash_rehash(struct nl_portid_hash *hash, int grow)
+static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
+ struct net *net)
{
- unsigned int omask, mask, shift;
- size_t osize, size;
- struct hlist_head *otable, *table;
- int i;
-
- omask = mask = hash->mask;
- osize = size = (mask + 1) * sizeof(*table);
- shift = hash->shift;
-
- if (grow) {
- if (++shift > hash->max_shift)
- return 0;
- mask = mask * 2 + 1;
- size *= 2;
- }
+ struct netlink_compare_arg arg = {
+ .net = net,
+ .portid = portid,
+ };
+ u32 hash;
- table = nl_portid_hash_zalloc(size);
- if (!table)
- return 0;
+ hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid));
- otable = hash->table;
- hash->table = table;
- hash->mask = mask;
- hash->shift = shift;
- get_random_bytes(&hash->rnd, sizeof(hash->rnd));
-
- for (i = 0; i <= omask; i++) {
- struct sock *sk;
- struct hlist_node *tmp;
-
- sk_for_each_safe(sk, tmp, &otable[i])
- __sk_add_node(sk, nl_portid_hashfn(hash, nlk_sk(sk)->portid));
- }
-
- nl_portid_hash_free(otable, osize);
- hash->rehash_time = jiffies + 10 * 60 * HZ;
- return 1;
+ return rhashtable_lookup_compare(&table->hash, hash,
+ &netlink_compare, &arg);
}
-static inline int nl_portid_hash_dilute(struct nl_portid_hash *hash, int len)
+static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
{
- int avg = hash->entries >> hash->shift;
-
- if (unlikely(avg > 1) && nl_portid_hash_rehash(hash, 1))
- return 1;
+ struct netlink_table *table = &nl_table[protocol];
+ struct sock *sk;
- if (unlikely(len > avg) && time_after(jiffies, hash->rehash_time)) {
- nl_portid_hash_rehash(hash, 0);
- return 1;
- }
+ rcu_read_lock();
+ sk = __netlink_lookup(table, portid, net);
+ if (sk)
+ sock_hold(sk);
+ rcu_read_unlock();
- return 0;
+ return sk;
}
static const struct proto_ops netlink_ops;
static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
{
struct netlink_table *table = &nl_table[sk->sk_protocol];
- struct nl_portid_hash *hash = &table->hash;
- struct hlist_head *head;
int err = -EADDRINUSE;
- struct sock *osk;
- int len;
- netlink_table_grab();
- head = nl_portid_hashfn(hash, portid);
- len = 0;
- sk_for_each(osk, head) {
- if (table->compare(net, osk) &&
- (nlk_sk(osk)->portid == portid))
- break;
- len++;
- }
- if (osk)
+ mutex_lock(&nl_sk_hash_lock);
+ if (__netlink_lookup(table, portid, net))
goto err;
err = -EBUSY;
goto err;
err = -ENOMEM;
- if (BITS_PER_LONG > 32 && unlikely(hash->entries >= UINT_MAX))
+ if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX))
goto err;
- if (len && nl_portid_hash_dilute(hash, len))
- head = nl_portid_hashfn(hash, portid);
- hash->entries++;
nlk_sk(sk)->portid = portid;
- sk_add_node(sk, head);
+ sock_hold(sk);
+ rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);
err = 0;
-
err:
- netlink_table_ungrab();
+ mutex_unlock(&nl_sk_hash_lock);
return err;
}
static void netlink_remove(struct sock *sk)
{
+ struct netlink_table *table;
+
+ mutex_lock(&nl_sk_hash_lock);
+ table = &nl_table[sk->sk_protocol];
+ if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL)) {
+ WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
+ __sock_put(sk);
+ }
+ mutex_unlock(&nl_sk_hash_lock);
+
netlink_table_grab();
- if (sk_del_node_init(sk))
- nl_table[sk->sk_protocol].hash.entries--;
if (nlk_sk(sk)->subscriptions)
__sk_del_bind_node(sk);
netlink_table_ungrab();
}
netlink_table_ungrab();
+ /* Wait for readers to complete */
+ synchronize_net();
+
kfree(nlk->groups);
nlk->groups = NULL;
struct sock *sk = sock->sk;
struct net *net = sock_net(sk);
struct netlink_table *table = &nl_table[sk->sk_protocol];
- struct nl_portid_hash *hash = &table->hash;
- struct hlist_head *head;
- struct sock *osk;
s32 portid = task_tgid_vnr(current);
int err;
static s32 rover = -4097;
retry:
cond_resched();
- netlink_table_grab();
- head = nl_portid_hashfn(hash, portid);
- sk_for_each(osk, head) {
- if (!table->compare(net, osk))
- continue;
- if (nlk_sk(osk)->portid == portid) {
- /* Bind collision, search negative portid values. */
- portid = rover--;
- if (rover > -4097)
- rover = -4097;
- netlink_table_ungrab();
- goto retry;
- }
+ rcu_read_lock();
+ if (__netlink_lookup(table, portid, net)) {
+ /* Bind collision, search negative portid values. */
+ portid = rover--;
+ if (rover > -4097)
+ rover = -4097;
+ rcu_read_unlock();
+ goto retry;
}
- netlink_table_ungrab();
+ rcu_read_unlock();
err = netlink_insert(sk, net, portid);
if (err == -EADDRINUSE)
{
struct nl_seq_iter *iter = seq->private;
int i, j;
+ struct netlink_sock *nlk;
struct sock *s;
loff_t off = 0;
for (i = 0; i < MAX_LINKS; i++) {
- struct nl_portid_hash *hash = &nl_table[i].hash;
+ struct rhashtable *ht = &nl_table[i].hash;
+ const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
+
+ for (j = 0; j < tbl->size; j++) {
+ rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
+ s = (struct sock *)nlk;
- for (j = 0; j <= hash->mask; j++) {
- sk_for_each(s, &hash->table[j]) {
if (sock_net(s) != seq_file_net(seq))
continue;
if (off == pos) {
}
static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
- __acquires(nl_table_lock)
{
- read_lock(&nl_table_lock);
+ rcu_read_lock();
return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;
}
static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
{
- struct sock *s;
+ struct netlink_sock *nlk;
struct nl_seq_iter *iter;
struct net *net;
int i, j;
net = seq_file_net(seq);
iter = seq->private;
- s = v;
- do {
- s = sk_next(s);
- } while (s && !nl_table[s->sk_protocol].compare(net, s));
- if (s)
- return s;
+ nlk = v;
+
+ rht_for_each_entry_rcu(nlk, nlk->node.next, node)
+ if (net_eq(sock_net((struct sock *)nlk), net))
+ return nlk;
i = iter->link;
j = iter->hash_idx + 1;
do {
- struct nl_portid_hash *hash = &nl_table[i].hash;
-
- for (; j <= hash->mask; j++) {
- s = sk_head(&hash->table[j]);
+ struct rhashtable *ht = &nl_table[i].hash;
+ const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
- while (s && !nl_table[s->sk_protocol].compare(net, s))
- s = sk_next(s);
- if (s) {
- iter->link = i;
- iter->hash_idx = j;
- return s;
+ for (; j < tbl->size; j++) {
+ rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
+ if (net_eq(sock_net((struct sock *)nlk), net)) {
+ iter->link = i;
+ iter->hash_idx = j;
+ return nlk;
+ }
}
}
}
static void netlink_seq_stop(struct seq_file *seq, void *v)
- __releases(nl_table_lock)
{
- read_unlock(&nl_table_lock);
+ rcu_read_unlock();
}
static int __init netlink_proto_init(void)
{
int i;
- unsigned long limit;
- unsigned int order;
int err = proto_register(&netlink_proto, 0);
+ struct rhashtable_params ht_params = {
+ .head_offset = offsetof(struct netlink_sock, node),
+ .key_offset = offsetof(struct netlink_sock, portid),
+ .key_len = sizeof(u32), /* portid */
+ .hashfn = arch_fast_hash,
+ .max_shift = 16, /* 64K */
+ .grow_decision = rht_grow_above_75,
+ .shrink_decision = rht_shrink_below_30,
+ .mutex_is_held = lockdep_nl_sk_hash_is_held,
+ };
if (err != 0)
goto out;
if (!nl_table)
goto panic;
- if (totalram_pages >= (128 * 1024))
- limit = totalram_pages >> (21 - PAGE_SHIFT);
- else
- limit = totalram_pages >> (23 - PAGE_SHIFT);
-
- order = get_bitmask_order(limit) - 1 + PAGE_SHIFT;
- limit = (1UL << order) / sizeof(struct hlist_head);
- order = get_bitmask_order(min(limit, (unsigned long)UINT_MAX)) - 1;
-
for (i = 0; i < MAX_LINKS; i++) {
- struct nl_portid_hash *hash = &nl_table[i].hash;
-
- hash->table = nl_portid_hash_zalloc(1 * sizeof(*hash->table));
- if (!hash->table) {
- while (i-- > 0)
- nl_portid_hash_free(nl_table[i].hash.table,
- 1 * sizeof(*hash->table));
+ if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) {
+ while (--i > 0)
+ rhashtable_destroy(&nl_table[i].hash);
kfree(nl_table);
goto panic;
}
- hash->max_shift = order;
- hash->shift = 0;
- hash->mask = 0;
- hash->rehash_time = jiffies;
-
- nl_table[i].compare = netlink_compare;
}
INIT_LIST_HEAD(&netlink_tap_all);