#include <linux/module.h>
#include <linux/errno.h>
#include <linux/socket.h>
#include <linux/skbuff.h>
#include <linux/ip.h>
#include <linux/udp.h>
#include <linux/types.h>
#include <linux/kernel.h>
#include <net/genetlink.h>
#include <net/gue.h>
#include <net/ip.h>
#include <net/protocol.h>
#include <net/udp.h>
#include <net/udp_tunnel.h>
#include <uapi/linux/fou.h>
#include <uapi/linux/genetlink.h>

static DEFINE_SPINLOCK(fou_lock);
static LIST_HEAD(fou_list);
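
/* Per-port state for one open FOU/GUE receive socket, kept on fou_list */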
struct fou {
        struct socket *sock;
        u8 protocol;
        u16 port;
        struct udp_offload udp_offloads;
        struct list_head list;
};

struct fou_cfg {
        u16 type;
        u8 protocol;
        struct udp_port_cfg udp_config;
};

static inline struct fou *fou_from_sock(struct sock *sk)
{
        return sk->sk_user_data;
}

static void fou_recv_pull(struct sk_buff *skb, size_t len)
{
        struct iphdr *iph = ip_hdr(skb);

        /* Remove 'len' bytes from the packet (UDP header and
         * FOU header if present).
         */
        iph->tot_len = htons(ntohs(iph->tot_len) - len);
        __skb_pull(skb, len);
        skb_postpull_rcsum(skb, udp_hdr(skb), len);
        skb_reset_transport_header(skb);
}
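
/* encap_rcv handler for plain FOU: strip the outer UDP header and return
 * the negative inner protocol number, which tells the IP stack to
 * resubmit the packet with that protocol.
 */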
static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
{
        struct fou *fou = fou_from_sock(sk);

        if (!fou)
                return 1;

        fou_recv_pull(skb, sizeof(struct udphdr));

        return -fou->protocol;
}
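
/* Handle the GUE remote checksum offload option on the normal receive
 * path: derive the inner checksum from the already-verified outer
 * checksum and write it into the field the option points at, fixing up
 * skb->csum for the bytes we changed.
 */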
static struct guehdr *gue_remcsum(struct sk_buff *skb, struct guehdr *guehdr,
                                  void *data, int hdrlen, u8 ipproto)
{
        __be16 *pd = data;
        u16 start = ntohs(pd[0]);
        u16 offset = ntohs(pd[1]);
        u16 poffset = 0;
        u16 plen;
        __wsum csum, delta;
        __sum16 *psum;

        if (skb->remcsum_offload) {
                /* Already processed in GRO path */
                skb->remcsum_offload = 0;
                return guehdr;
        }

        if (start > skb->len - hdrlen ||
            offset > skb->len - hdrlen - sizeof(u16))
                return NULL;

        if (unlikely(skb->ip_summed != CHECKSUM_COMPLETE))
                __skb_checksum_complete(skb);

        plen = hdrlen + offset + sizeof(u16);
        if (!pskb_may_pull(skb, plen))
                return NULL;
        guehdr = (struct guehdr *)&udp_hdr(skb)[1];

        if (ipproto == IPPROTO_IP && sizeof(struct iphdr) < plen) {
                struct iphdr *ip = (struct iphdr *)(skb->data + hdrlen);

                /* If next header happens to be IP we can skip that for the
                 * checksum calculation since the IP header checksum is zero
                 * if correct.
                 */
                poffset = ip->ihl * 4;
        }

        csum = csum_sub(skb->csum, skb_checksum(skb, poffset + hdrlen,
                                                start - poffset - hdrlen, 0));

        /* Set derived checksum in packet */
        psum = (__sum16 *)(skb->data + hdrlen + offset);
        delta = csum_sub(csum_fold(csum), *psum);
        *psum = csum_fold(csum);

        /* Adjust skb->csum since we changed the packet */
        skb->csum = csum_add(skb->csum, delta);

        return guehdr;
}

static int gue_control_message(struct sk_buff *skb, struct guehdr *guehdr)
{
        /* No support yet */
        kfree_skb(skb);
        return 0;
}
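
/* encap_rcv handler for GUE: validate the GUE header, process private
 * flags such as remote checksum offload, then return the negative inner
 * protocol number so the IP stack resubmits the inner packet.
 */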
static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
{
        struct fou *fou = fou_from_sock(sk);
        size_t len, optlen, hdrlen;
        struct guehdr *guehdr;
        void *data;
        u16 doffset = 0;

        if (!fou)
                return 1;

        len = sizeof(struct udphdr) + sizeof(struct guehdr);
        if (!pskb_may_pull(skb, len))
                goto drop;

        guehdr = (struct guehdr *)&udp_hdr(skb)[1];

        optlen = guehdr->hlen << 2;
        len += optlen;

        if (!pskb_may_pull(skb, len))
                goto drop;

        /* guehdr may change after pull */
        guehdr = (struct guehdr *)&udp_hdr(skb)[1];

        if (guehdr->version != 0 || validate_gue_flags(guehdr, optlen))
                goto drop;

        hdrlen = sizeof(struct guehdr) + optlen;

        ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);

        /* Pull UDP header now, skb->data points to guehdr */
        __skb_pull(skb, sizeof(struct udphdr));

        /* Pull csum through the guehdr now. This can be used if
         * there is a remote checksum offload.
         */
        skb_postpull_rcsum(skb, udp_hdr(skb), len);

        data = &guehdr[1];

        if (guehdr->flags & GUE_FLAG_PRIV) {
                __be32 flags = *(__be32 *)(data + doffset);

                doffset += GUE_LEN_PRIV;

                if (flags & GUE_PFLAG_REMCSUM) {
                        guehdr = gue_remcsum(skb, guehdr, data + doffset,
                                             hdrlen, guehdr->proto_ctype);
                        if (!guehdr)
                                goto drop;

                        data = &guehdr[1];

                        doffset += GUE_PLEN_REMCSUM;
                }
        }

        if (unlikely(guehdr->control))
                return gue_control_message(skb, guehdr);

        __skb_pull(skb, hdrlen);
        skb_reset_transport_header(skb);

        return -guehdr->proto_ctype;

drop:
        kfree_skb(skb);
        return 0;
}
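
/* GRO receive for direct FOU: hand the aggregated skb chain to the
 * offload handler registered for the inner protocol.
 */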
static struct sk_buff **fou_gro_receive(struct sk_buff **head,
                                        struct sk_buff *skb)
{
        const struct net_offload *ops;
        struct sk_buff **pp = NULL;
        u8 proto = NAPI_GRO_CB(skb)->proto;
        const struct net_offload **offloads;

        rcu_read_lock();
        offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
        ops = rcu_dereference(offloads[proto]);
        if (!ops || !ops->callbacks.gro_receive)
                goto out_unlock;

        pp = ops->callbacks.gro_receive(head, skb);

out_unlock:
        rcu_read_unlock();

        return pp;
}

static int fou_gro_complete(struct sk_buff *skb, int nhoff)
{
        const struct net_offload *ops;
        u8 proto = NAPI_GRO_CB(skb)->proto;
        int err = -ENOSYS;
        const struct net_offload **offloads;

        udp_tunnel_gro_complete(skb, nhoff);

        rcu_read_lock();
        offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
        ops = rcu_dereference(offloads[proto]);
        if (WARN_ON(!ops || !ops->callbacks.gro_complete))
                goto out_unlock;

        err = ops->callbacks.gro_complete(skb, nhoff);

out_unlock:
        rcu_read_unlock();

        return err;
}
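
/* GRO-path variant of GUE remote checksum offload: patch the inner
 * checksum using NAPI_GRO_CB(skb)->csum and mark the skb so the
 * non-GRO path does not process the option a second time.
 */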
static struct guehdr *gue_gro_remcsum(struct sk_buff *skb, unsigned int off,
                                      struct guehdr *guehdr, void *data,
                                      size_t hdrlen, u8 ipproto)
{
        __be16 *pd = data;
        u16 start = ntohs(pd[0]);
        u16 offset = ntohs(pd[1]);
        u16 poffset = 0;
        u16 plen;
        void *ptr;
        __wsum csum, delta;
        __sum16 *psum;

        if (skb->remcsum_offload)
                return guehdr;

        if (start > skb_gro_len(skb) - hdrlen ||
            offset > skb_gro_len(skb) - hdrlen - sizeof(u16) ||
            !NAPI_GRO_CB(skb)->csum_valid || skb->remcsum_offload)
                return NULL;

        plen = hdrlen + offset + sizeof(u16);

        /* Pull checksum that will be written */
        if (skb_gro_header_hard(skb, off + plen)) {
                guehdr = skb_gro_header_slow(skb, off + plen, off);
                if (!guehdr)
                        return NULL;
        }

        ptr = (void *)guehdr + hdrlen;

        if (ipproto == IPPROTO_IP &&
            (hdrlen + sizeof(struct iphdr) < plen)) {
                struct iphdr *ip = (struct iphdr *)(ptr + hdrlen);

                /* If next header happens to be IP we can skip
                 * that for the checksum calculation since the
                 * IP header checksum is zero if correct.
                 */
                poffset = ip->ihl * 4;
        }

        csum = csum_sub(NAPI_GRO_CB(skb)->csum,
                        csum_partial(ptr + poffset, start - poffset, 0));

        /* Set derived checksum in packet */
        psum = (__sum16 *)(ptr + offset);
        delta = csum_sub(csum_fold(csum), *psum);
        *psum = csum_fold(csum);

        /* Adjust skb->csum since we changed the packet */
        skb->csum = csum_add(skb->csum, delta);
        NAPI_GRO_CB(skb)->csum = csum_add(NAPI_GRO_CB(skb)->csum, delta);

        skb->remcsum_offload = 1;

        return guehdr;
}
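
/* GRO receive for GUE: validate the header, apply remote checksum
 * offload if requested, match flows on the full GUE header (base word
 * plus options), and delegate to the inner protocol's GRO handler.
 */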
static struct sk_buff **gue_gro_receive(struct sk_buff **head,
                                        struct sk_buff *skb)
{
        const struct net_offload **offloads;
        const struct net_offload *ops;
        struct sk_buff **pp = NULL;
        struct sk_buff *p;
        struct guehdr *guehdr;
        size_t len, optlen, hdrlen, off;
        void *data;
        u16 doffset = 0;
        int flush = 1;

        off = skb_gro_offset(skb);
        len = off + sizeof(*guehdr);

        guehdr = skb_gro_header_fast(skb, off);
        if (skb_gro_header_hard(skb, len)) {
                guehdr = skb_gro_header_slow(skb, len, off);
                if (unlikely(!guehdr))
                        goto out;
        }

        optlen = guehdr->hlen << 2;
        len += optlen;

        if (skb_gro_header_hard(skb, len)) {
                guehdr = skb_gro_header_slow(skb, len, off);
                if (unlikely(!guehdr))
                        goto out;
        }

        if (unlikely(guehdr->control) || guehdr->version != 0 ||
            validate_gue_flags(guehdr, optlen))
                goto out;

        hdrlen = sizeof(*guehdr) + optlen;

        /* Adjust NAPI_GRO_CB(skb)->csum to account for guehdr,
         * this is needed if there is a remote checksum offload.
         */
        skb_gro_postpull_rcsum(skb, guehdr, hdrlen);

        data = &guehdr[1];

        if (guehdr->flags & GUE_FLAG_PRIV) {
                __be32 flags = *(__be32 *)(data + doffset);

                doffset += GUE_LEN_PRIV;

                if (flags & GUE_PFLAG_REMCSUM) {
                        guehdr = gue_gro_remcsum(skb, off, guehdr,
                                                 data + doffset, hdrlen,
                                                 guehdr->proto_ctype);
                        if (!guehdr)
                                goto out;

                        data = &guehdr[1];

                        doffset += GUE_PLEN_REMCSUM;
                }
        }

        skb_gro_pull(skb, hdrlen);

        flush = 0;

        for (p = *head; p; p = p->next) {
                const struct guehdr *guehdr2;

                if (!NAPI_GRO_CB(p)->same_flow)
                        continue;

                guehdr2 = (struct guehdr *)(p->data + off);

                /* Compare base GUE header to be equal (covers
                 * hlen, version, proto_ctype, and flags).
                 */
                if (guehdr->word != guehdr2->word) {
                        NAPI_GRO_CB(p)->same_flow = 0;
                        continue;
                }

                /* Require the optional fields to match as well. */
                if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
                                           guehdr->hlen << 2)) {
                        NAPI_GRO_CB(p)->same_flow = 0;
                        continue;
                }
        }

        rcu_read_lock();
        offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
        ops = rcu_dereference(offloads[guehdr->proto_ctype]);
        if (WARN_ON(!ops || !ops->callbacks.gro_receive))
                goto out_unlock;

        pp = ops->callbacks.gro_receive(head, skb);

out_unlock:
        rcu_read_unlock();
out:
        NAPI_GRO_CB(skb)->flush |= flush;

        return pp;
}

static int gue_gro_complete(struct sk_buff *skb, int nhoff)
{
        const struct net_offload **offloads;
        struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
        const struct net_offload *ops;
        unsigned int guehlen;
        u8 proto;
        int err = -ENOENT;

        proto = guehdr->proto_ctype;

        guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);

        rcu_read_lock();
        offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
        ops = rcu_dereference(offloads[proto]);
        if (WARN_ON(!ops || !ops->callbacks.gro_complete))
                goto out_unlock;

        err = ops->callbacks.gro_complete(skb, nhoff + guehlen);

out_unlock:
        rcu_read_unlock();

        return err;
}
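
/* All active ports sit on a single global list; fou_lock serializes
 * add, delete, and module-exit teardown.
 */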
static int fou_add_to_port_list(struct fou *fou)
{
        struct fou *fout;

        spin_lock(&fou_lock);
        list_for_each_entry(fout, &fou_list, list) {
                if (fou->port == fout->port) {
                        spin_unlock(&fou_lock);
                        return -EALREADY;
                }
        }

        list_add(&fou->list, &fou_list);
        spin_unlock(&fou_lock);

        return 0;
}

static void fou_release(struct fou *fou)
{
        struct socket *sock = fou->sock;
        struct sock *sk = sock->sk;

        udp_del_offload(&fou->udp_offloads);

        list_del(&fou->list);

        /* Remove hooks into tunnel socket */
        sk->sk_user_data = NULL;

        sock_release(sock);

        kfree(fou);
}

static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
{
        udp_sk(sk)->encap_rcv = fou_udp_recv;
        fou->protocol = cfg->protocol;
        fou->udp_offloads.callbacks.gro_receive = fou_gro_receive;
        fou->udp_offloads.callbacks.gro_complete = fou_gro_complete;
        fou->udp_offloads.port = cfg->udp_config.local_udp_port;
        fou->udp_offloads.ipproto = cfg->protocol;

        return 0;
}

static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
{
        udp_sk(sk)->encap_rcv = gue_udp_recv;
        fou->udp_offloads.callbacks.gro_receive = gue_gro_receive;
        fou->udp_offloads.callbacks.gro_complete = gue_gro_complete;
        fou->udp_offloads.port = cfg->udp_config.local_udp_port;

        return 0;
}

static int fou_create(struct net *net, struct fou_cfg *cfg,
                      struct socket **sockp)
{
        struct fou *fou = NULL;
        int err;
        struct socket *sock = NULL;
        struct sock *sk;

        /* Open UDP socket */
        err = udp_sock_create(net, &cfg->udp_config, &sock);
        if (err < 0)
                goto error;

        /* Allocate FOU port structure */
        fou = kzalloc(sizeof(*fou), GFP_KERNEL);
        if (!fou) {
                err = -ENOMEM;
                goto error;
        }

        sk = sock->sk;

        fou->port = cfg->udp_config.local_udp_port;

        /* Initialize for the configured encapsulation type */
        switch (cfg->type) {
        case FOU_ENCAP_DIRECT:
                err = fou_encap_init(sk, fou, cfg);
                if (err)
                        goto error;
                break;
        case FOU_ENCAP_GUE:
                err = gue_encap_init(sk, fou, cfg);
                if (err)
                        goto error;
                break;
        default:
                err = -EINVAL;
                goto error;
        }

        udp_sk(sk)->encap_type = 1;
        udp_encap_enable();

        sk->sk_user_data = fou;
        fou->sock = sock;

        udp_set_convert_csum(sk, true);

        sk->sk_allocation = GFP_ATOMIC;

        if (cfg->udp_config.family == AF_INET) {
                err = udp_add_offload(&fou->udp_offloads);
                if (err)
                        goto error;
        }

        err = fou_add_to_port_list(fou);
        if (err)
                goto error;

        if (sockp)
                *sockp = sock;

        return 0;

error:
        kfree(fou);
        if (sock)
                sock_release(sock);

        return err;
}

static int fou_destroy(struct net *net, struct fou_cfg *cfg)
{
        struct fou *fou;
        u16 port = cfg->udp_config.local_udp_port;
        int err = -EINVAL;

        spin_lock(&fou_lock);
        list_for_each_entry(fou, &fou_list, list) {
                if (fou->port == port) {
                        udp_del_offload(&fou->udp_offloads);
                        fou_release(fou);
                        err = 0;
                        break;
                }
        }
        spin_unlock(&fou_lock);

        return err;
}
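
/* Generic netlink family; GENL_ID_GENERATE asks the core to assign a
 * dynamic family id at registration time.
 */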
static struct genl_family fou_nl_family = {
        .id             = GENL_ID_GENERATE,
        .hdrsize        = 0,
        .name           = FOU_GENL_NAME,
        .version        = FOU_GENL_VERSION,
        .maxattr        = FOU_ATTR_MAX,
        .netnsok        = true,
};

static struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
        [FOU_ATTR_PORT] = { .type = NLA_U16, },
        [FOU_ATTR_AF] = { .type = NLA_U8, },
        [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
        [FOU_ATTR_TYPE] = { .type = NLA_U8, },
};

static int parse_nl_config(struct genl_info *info,
                           struct fou_cfg *cfg)
{
        memset(cfg, 0, sizeof(*cfg));

        cfg->udp_config.family = AF_INET;

        if (info->attrs[FOU_ATTR_AF]) {
                u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);

                if (family != AF_INET && family != AF_INET6)
                        return -EINVAL;

                cfg->udp_config.family = family;
        }

        if (info->attrs[FOU_ATTR_PORT]) {
                u16 port = nla_get_u16(info->attrs[FOU_ATTR_PORT]);

                cfg->udp_config.local_udp_port = port;
        }

        if (info->attrs[FOU_ATTR_IPPROTO])
                cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);

        if (info->attrs[FOU_ATTR_TYPE])
                cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);

        return 0;
}

static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
{
        struct fou_cfg cfg;
        int err;

        err = parse_nl_config(info, &cfg);
        if (err)
                return err;

        return fou_create(&init_net, &cfg, NULL);
}

static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
{
        struct fou_cfg cfg;

        parse_nl_config(info, &cfg);

        return fou_destroy(&init_net, &cfg);
}

static const struct genl_ops fou_nl_ops[] = {
        {
                .cmd = FOU_CMD_ADD,
                .doit = fou_nl_cmd_add_port,
                .policy = fou_nl_policy,
                .flags = GENL_ADMIN_PERM,
        },
        {
                .cmd = FOU_CMD_DEL,
                .doit = fou_nl_cmd_rm_port,
                .policy = fou_nl_policy,
                .flags = GENL_ADMIN_PERM,
        },
};

size_t fou_encap_hlen(struct ip_tunnel_encap *e)
{
        return sizeof(struct udphdr);
}
EXPORT_SYMBOL(fou_encap_hlen);

size_t gue_encap_hlen(struct ip_tunnel_encap *e)
{
        size_t len;
        bool need_priv = false;

        len = sizeof(struct udphdr) + sizeof(struct guehdr);

        if (e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) {
                len += GUE_PLEN_REMCSUM;
                need_priv = true;
        }

        len += need_priv ? GUE_LEN_PRIV : 0;

        return len;
}
EXPORT_SYMBOL(gue_encap_hlen);
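
/* Push the outer UDP header for an encapsulated frame and set up its
 * checksum according to the configured encap flags.
 */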
static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e,
                          struct flowi4 *fl4, u8 *protocol, __be16 sport)
{
        struct udphdr *uh;

        skb_push(skb, sizeof(struct udphdr));
        skb_reset_transport_header(skb);

        uh = udp_hdr(skb);

        uh->dest = e->dport;
        uh->source = sport;
        uh->len = htons(skb->len);
        udp_set_csum(!(e->flags & TUNNEL_ENCAP_FLAG_CSUM), skb,
                     fl4->saddr, fl4->daddr, skb->len);

        *protocol = IPPROTO_UDP;
}

int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
                     u8 *protocol, struct flowi4 *fl4)
{
        bool csum = !!(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
        int type = csum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
        __be16 sport;

        skb = iptunnel_handle_offloads(skb, csum, type);

        if (IS_ERR(skb))
                return PTR_ERR(skb);

        sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
                                               skb, 0, 0, false);
        fou_build_udp(skb, e, fl4, protocol, sport);

        return 0;
}
EXPORT_SYMBOL(fou_build_header);
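
/* Build the GUE header in front of the inner packet. When remote
 * checksum offload applies, a private flags word plus a (start, offset)
 * pair describing the inner checksum field is carried as a GUE option.
 */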
int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
                     u8 *protocol, struct flowi4 *fl4)
{
        bool csum = !!(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
        int type = csum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
        struct guehdr *guehdr;
        size_t hdrlen, optlen = 0;
        __be16 sport;
        void *data;
        bool need_priv = false;

        if ((e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) &&
            skb->ip_summed == CHECKSUM_PARTIAL) {
                csum = false;
                optlen += GUE_PLEN_REMCSUM;
                type |= SKB_GSO_TUNNEL_REMCSUM;
                need_priv = true;
        }

        optlen += need_priv ? GUE_LEN_PRIV : 0;

        skb = iptunnel_handle_offloads(skb, csum, type);

        if (IS_ERR(skb))
                return PTR_ERR(skb);

        /* Get source port (based on flow hash) before skb_push */
        sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
                                               skb, 0, 0, false);

        hdrlen = sizeof(struct guehdr) + optlen;

        skb_push(skb, hdrlen);

        guehdr = (struct guehdr *)skb->data;

        guehdr->control = 0;
        guehdr->version = 0;
        guehdr->hlen = optlen >> 2;
        guehdr->flags = 0;
        guehdr->proto_ctype = *protocol;

        data = &guehdr[1];

        if (need_priv) {
                __be32 *flags = data;

                guehdr->flags |= GUE_FLAG_PRIV;
                *flags = 0;
                data += GUE_LEN_PRIV;

                if (type & SKB_GSO_TUNNEL_REMCSUM) {
                        u16 csum_start = skb_checksum_start_offset(skb);
                        __be16 *pd = data;

                        if (csum_start < hdrlen)
                                return -EINVAL;

                        csum_start -= hdrlen;
                        pd[0] = htons(csum_start);
                        pd[1] = htons(csum_start + skb->csum_offset);

                        if (!skb_is_gso(skb)) {
                                skb->ip_summed = CHECKSUM_NONE;
                                skb->encapsulation = 0;
                        }

                        *flags |= GUE_PFLAG_REMCSUM;
                        data += GUE_PLEN_REMCSUM;
                }
        }

        fou_build_udp(skb, e, fl4, protocol, sport);

        return 0;
}
EXPORT_SYMBOL(gue_build_header);

#ifdef CONFIG_NET_FOU_IP_TUNNELS

static const struct ip_tunnel_encap_ops __read_mostly fou_iptun_ops = {
        .encap_hlen = fou_encap_hlen,
        .build_header = fou_build_header,
};

static const struct ip_tunnel_encap_ops __read_mostly gue_iptun_ops = {
        .encap_hlen = gue_encap_hlen,
        .build_header = gue_build_header,
};

static int ip_tunnel_encap_add_fou_ops(void)
{
        int ret;

        ret = ip_tunnel_encap_add_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
        if (ret < 0) {
                pr_err("can't add fou ops\n");
                return ret;
        }

        ret = ip_tunnel_encap_add_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
        if (ret < 0) {
                pr_err("can't add gue ops\n");
                ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
                return ret;
        }

        return 0;
}

static void ip_tunnel_encap_del_fou_ops(void)
{
        ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
        ip_tunnel_encap_del_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
}

#else

static int ip_tunnel_encap_add_fou_ops(void)
{
        return 0;
}

static void ip_tunnel_encap_del_fou_ops(void)
{
}

#endif

static int __init fou_init(void)
{
        int ret;

        ret = genl_register_family_with_ops(&fou_nl_family,
                                            fou_nl_ops);
        if (ret < 0)
                goto exit;

        ret = ip_tunnel_encap_add_fou_ops();
        if (ret < 0)
                genl_unregister_family(&fou_nl_family);

exit:
        return ret;
}

static void __exit fou_fini(void)
{
        struct fou *fou, *next;

        ip_tunnel_encap_del_fou_ops();

        genl_unregister_family(&fou_nl_family);

        /* Close all the FOU sockets */
        spin_lock(&fou_lock);
        list_for_each_entry_safe(fou, next, &fou_list, list)
                fou_release(fou);
        spin_unlock(&fou_lock);
}

module_init(fou_init);
module_exit(fou_fini);
MODULE_LICENSE("GPL");