#ifndef __SOCK_DIAG_H__
#define __SOCK_DIAG_H__
+#include <linux/netlink.h>
#include <linux/user_namespace.h>
+#include <net/net_namespace.h>
+#include <net/sock.h>
#include <uapi/linux/sock_diag.h>
struct sk_buff;
struct sock_diag_handler {
__u8 family;
int (*dump)(struct sk_buff *skb, struct nlmsghdr *nlh);
+ int (*get_info)(struct sk_buff *skb, struct sock *sk);
};
int sock_diag_register(const struct sock_diag_handler *h);
int sock_diag_put_filterinfo(bool may_report_filterinfo, struct sock *sk,
struct sk_buff *skb, int attrtype);
+static inline
+enum sknetlink_groups sock_diag_destroy_group(const struct sock *sk)
+{
+ switch (sk->sk_family) {
+ case AF_INET:
+ switch (sk->sk_protocol) {
+ case IPPROTO_TCP:
+ return SKNLGRP_INET_TCP_DESTROY;
+ case IPPROTO_UDP:
+ return SKNLGRP_INET_UDP_DESTROY;
+ default:
+ return SKNLGRP_NONE;
+ }
+ case AF_INET6:
+ switch (sk->sk_protocol) {
+ case IPPROTO_TCP:
+ return SKNLGRP_INET6_TCP_DESTROY;
+ case IPPROTO_UDP:
+ return SKNLGRP_INET6_UDP_DESTROY;
+ default:
+ return SKNLGRP_NONE;
+ }
+ default:
+ return SKNLGRP_NONE;
+ }
+}
+
+static inline
+bool sock_diag_has_destroy_listeners(const struct sock *sk)
+{
+ const struct net *n = sock_net(sk);
+ const enum sknetlink_groups group = sock_diag_destroy_group(sk);
+
+ return group != SKNLGRP_NONE && n->diag_nlsk &&
+ netlink_has_listeners(n->diag_nlsk, group);
+}
+void sock_diag_broadcast_destroy(struct sock *sk);
+
#endif
#include <net/net_namespace.h>
#include <linux/module.h>
#include <net/sock.h>
+#include <linux/kernel.h>
+#include <linux/tcp.h>
+#include <linux/workqueue.h>
#include <linux/inet_diag.h>
#include <linux/sock_diag.h>
static const struct sock_diag_handler *sock_diag_handlers[AF_MAX];
static int (*inet_rcv_compat)(struct sk_buff *skb, struct nlmsghdr *nlh);
static DEFINE_MUTEX(sock_diag_table_mutex);
+static struct workqueue_struct *broadcast_wq;
static u64 sock_gen_cookie(struct sock *sk)
{
}
EXPORT_SYMBOL(sock_diag_put_filterinfo);
+struct broadcast_sk {
+ struct sock *sk;
+ struct work_struct work;
+};
+
+static size_t sock_diag_nlmsg_size(void)
+{
+ return NLMSG_ALIGN(sizeof(struct inet_diag_msg)
+ + nla_total_size(sizeof(u8)) /* INET_DIAG_PROTOCOL */
+ + nla_total_size(sizeof(struct tcp_info))); /* INET_DIAG_INFO */
+}
+
+static void sock_diag_broadcast_destroy_work(struct work_struct *work)
+{
+ struct broadcast_sk *bsk =
+ container_of(work, struct broadcast_sk, work);
+ struct sock *sk = bsk->sk;
+ const struct sock_diag_handler *hndl;
+ struct sk_buff *skb;
+ const enum sknetlink_groups group = sock_diag_destroy_group(sk);
+ int err = -1;
+
+ WARN_ON(group == SKNLGRP_NONE);
+
+ skb = nlmsg_new(sock_diag_nlmsg_size(), GFP_KERNEL);
+ if (!skb)
+ goto out;
+
+ mutex_lock(&sock_diag_table_mutex);
+ hndl = sock_diag_handlers[sk->sk_family];
+ if (hndl && hndl->get_info)
+ err = hndl->get_info(skb, sk);
+ mutex_unlock(&sock_diag_table_mutex);
+
+ if (!err)
+ nlmsg_multicast(sock_net(sk)->diag_nlsk, skb, 0, group,
+ GFP_KERNEL);
+ else
+ kfree_skb(skb);
+out:
+ sk_destruct(sk);
+ kfree(bsk);
+}
+
+void sock_diag_broadcast_destroy(struct sock *sk)
+{
+ /* Note, this function is often called from an interrupt context. */
+ struct broadcast_sk *bsk =
+ kmalloc(sizeof(struct broadcast_sk), GFP_ATOMIC);
+ if (!bsk)
+ return sk_destruct(sk);
+ bsk->sk = sk;
+ INIT_WORK(&bsk->work, sock_diag_broadcast_destroy_work);
+ queue_work(broadcast_wq, &bsk->work);
+}
+
void sock_diag_register_inet_compat(int (*fn)(struct sk_buff *skb, struct nlmsghdr *nlh))
{
mutex_lock(&sock_diag_table_mutex);
mutex_unlock(&sock_diag_mutex);
}
+static int sock_diag_bind(struct net *net, int group)
+{
+ switch (group) {
+ case SKNLGRP_INET_TCP_DESTROY:
+ case SKNLGRP_INET_UDP_DESTROY:
+ if (!sock_diag_handlers[AF_INET])
+ request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
+ NETLINK_SOCK_DIAG, AF_INET);
+ break;
+ case SKNLGRP_INET6_TCP_DESTROY:
+ case SKNLGRP_INET6_UDP_DESTROY:
+ if (!sock_diag_handlers[AF_INET6])
+ request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
+ NETLINK_SOCK_DIAG, AF_INET);
+ break;
+ }
+ return 0;
+}
+
static int __net_init diag_net_init(struct net *net)
{
struct netlink_kernel_cfg cfg = {
+ .groups = SKNLGRP_MAX,
.input = sock_diag_rcv,
+ .bind = sock_diag_bind,
+ .flags = NL_CFG_F_NONROOT_RECV,
};
net->diag_nlsk = netlink_kernel_create(net, NETLINK_SOCK_DIAG, &cfg);
static int __init sock_diag_init(void)
{
+ broadcast_wq = alloc_workqueue("sock_diag_events", 0, 0);
+ BUG_ON(!broadcast_wq);
return register_pernet_subsys(&diag_net_ops);
}
static void __exit sock_diag_exit(void)
{
unregister_pernet_subsys(&diag_net_ops);
+ destroy_workqueue(broadcast_wq);
}
module_init(sock_diag_init);