arch/x86/net/bpf_jit_comp.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler
4  *
5  * Copyright (C) 2011-2013 Eric Dumazet ([email protected])
6  * Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
7  */
8 #include <linux/netdevice.h>
9 #include <linux/filter.h>
10 #include <linux/if_vlan.h>
11 #include <linux/bpf.h>
12 #include <linux/memory.h>
13 #include <linux/sort.h>
14 #include <asm/extable.h>
15 #include <asm/ftrace.h>
16 #include <asm/set_memory.h>
17 #include <asm/nospec-branch.h>
18 #include <asm/text-patching.h>
19 #include <asm/unwind.h>
20 #include <asm/cfi.h>
21
22 static bool all_callee_regs_used[4] = {true, true, true, true};
23
24 static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
25 {
26         if (len == 1)
27                 *ptr = bytes;
28         else if (len == 2)
29                 *(u16 *)ptr = bytes;
30         else {
31                 *(u32 *)ptr = bytes;
32                 barrier();
33         }
34         return ptr + len;
35 }
36
37 #define EMIT(bytes, len) \
38         do { prog = emit_code(prog, bytes, len); } while (0)
39
40 #define EMIT1(b1)               EMIT(b1, 1)
41 #define EMIT2(b1, b2)           EMIT((b1) + ((b2) << 8), 2)
42 #define EMIT3(b1, b2, b3)       EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
43 #define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
44
45 #define EMIT1_off32(b1, off) \
46         do { EMIT1(b1); EMIT(off, 4); } while (0)
47 #define EMIT2_off32(b1, b2, off) \
48         do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
49 #define EMIT3_off32(b1, b2, b3, off) \
50         do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
51 #define EMIT4_off32(b1, b2, b3, b4, off) \
52         do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
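/*
 * The EMITn() helpers pack up to four opcode bytes little-endian into a u32
 * and copy them to the current output position; e.g. EMIT3(0x48, 0x89, 0xC7)
 * appends the bytes 48 89 c7 ("mov rdi, rax") and advances prog by three.
 * The *_off32 variants append a 4-byte immediate or displacement afterwards.
 */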
53
54 #ifdef CONFIG_X86_KERNEL_IBT
55 #define EMIT_ENDBR()            EMIT(gen_endbr(), 4)
56 #define EMIT_ENDBR_POISON()     EMIT(gen_endbr_poison(), 4)
57 #else
58 #define EMIT_ENDBR()
59 #define EMIT_ENDBR_POISON()
60 #endif
61
62 static bool is_imm8(int value)
63 {
64         return value <= 127 && value >= -128;
65 }
66
67 static bool is_simm32(s64 value)
68 {
69         return value == (s64)(s32)value;
70 }
71
72 static bool is_uimm32(u64 value)
73 {
74         return value == (u64)(u32)value;
75 }
76
77 /* mov dst, src */
78 #define EMIT_mov(DST, SRC)                                                               \
79         do {                                                                             \
80                 if (DST != SRC)                                                          \
81                         EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
82         } while (0)
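/*
 * EMIT_mov() is a no-op when DST == SRC; otherwise it emits a REX.W-prefixed
 * register-to-register move, e.g. EMIT_mov(BPF_REG_1, BPF_REG_2) produces
 * 48 89 f7 ("mov rdi, rsi") via the reg2hex[] mapping below.
 */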
83
84 static int bpf_size_to_x86_bytes(int bpf_size)
85 {
86         if (bpf_size == BPF_W)
87                 return 4;
88         else if (bpf_size == BPF_H)
89                 return 2;
90         else if (bpf_size == BPF_B)
91                 return 1;
92         else if (bpf_size == BPF_DW)
93                 return 4; /* imm32 */
94         else
95                 return 0;
96 }
97
98 /*
99  * List of x86 conditional jump opcodes (. + s8)
100  * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
101  */
102 #define X86_JB  0x72
103 #define X86_JAE 0x73
104 #define X86_JE  0x74
105 #define X86_JNE 0x75
106 #define X86_JBE 0x76
107 #define X86_JA  0x77
108 #define X86_JL  0x7C
109 #define X86_JGE 0x7D
110 #define X86_JLE 0x7E
111 #define X86_JG  0x7F
112
113 /* Pick a register outside of BPF range for JIT internal work */
114 #define AUX_REG (MAX_BPF_JIT_REG + 1)
115 #define X86_REG_R9 (MAX_BPF_JIT_REG + 2)
116
117 /*
118  * The following table maps BPF registers to x86-64 registers.
119  *
120  * x86-64 register R12 is unused, since if used as base address
121  * register in load/store instructions, it always needs an
122  * extra byte of encoding and is callee saved.
123  *
124  * x86-64 register R9 is not used by BPF programs, but can be used by BPF
125  * trampoline. x86-64 register R10 is used for blinding (if enabled).
126  */
127 static const int reg2hex[] = {
128         [BPF_REG_0] = 0,  /* RAX */
129         [BPF_REG_1] = 7,  /* RDI */
130         [BPF_REG_2] = 6,  /* RSI */
131         [BPF_REG_3] = 2,  /* RDX */
132         [BPF_REG_4] = 1,  /* RCX */
133         [BPF_REG_5] = 0,  /* R8  */
134         [BPF_REG_6] = 3,  /* RBX callee saved */
135         [BPF_REG_7] = 5,  /* R13 callee saved */
136         [BPF_REG_8] = 6,  /* R14 callee saved */
137         [BPF_REG_9] = 7,  /* R15 callee saved */
138         [BPF_REG_FP] = 5, /* RBP readonly */
139         [BPF_REG_AX] = 2, /* R10 temp register */
140         [AUX_REG] = 3,    /* R11 temp register */
141         [X86_REG_R9] = 1, /* R9 register, 6th function argument */
142 };
143
144 static const int reg2pt_regs[] = {
145         [BPF_REG_0] = offsetof(struct pt_regs, ax),
146         [BPF_REG_1] = offsetof(struct pt_regs, di),
147         [BPF_REG_2] = offsetof(struct pt_regs, si),
148         [BPF_REG_3] = offsetof(struct pt_regs, dx),
149         [BPF_REG_4] = offsetof(struct pt_regs, cx),
150         [BPF_REG_5] = offsetof(struct pt_regs, r8),
151         [BPF_REG_6] = offsetof(struct pt_regs, bx),
152         [BPF_REG_7] = offsetof(struct pt_regs, r13),
153         [BPF_REG_8] = offsetof(struct pt_regs, r14),
154         [BPF_REG_9] = offsetof(struct pt_regs, r15),
155 };
156
157 /*
158  * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15
159  * which need extra byte of encoding.
160  * rax,rcx,...,rbp have simpler encoding
161  */
162 static bool is_ereg(u32 reg)
163 {
164         return (1 << reg) & (BIT(BPF_REG_5) |
165                              BIT(AUX_REG) |
166                              BIT(BPF_REG_7) |
167                              BIT(BPF_REG_8) |
168                              BIT(BPF_REG_9) |
169                              BIT(X86_REG_R9) |
170                              BIT(BPF_REG_AX));
171 }
172
173 /*
174  * is_ereg_8l() == true if BPF register 'reg' is mapped to access x86-64
175  * lower 8-bit registers dil,sil,bpl,spl,r8b..r15b, which need extra byte
176  * of encoding. al,cl,dl,bl have simpler encoding.
177  */
178 static bool is_ereg_8l(u32 reg)
179 {
180         return is_ereg(reg) ||
181             (1 << reg) & (BIT(BPF_REG_1) |
182                           BIT(BPF_REG_2) |
183                           BIT(BPF_REG_FP));
184 }
185
186 static bool is_axreg(u32 reg)
187 {
188         return reg == BPF_REG_0;
189 }
190
191 /* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
192 static u8 add_1mod(u8 byte, u32 reg)
193 {
194         if (is_ereg(reg))
195                 byte |= 1;
196         return byte;
197 }
198
199 static u8 add_2mod(u8 byte, u32 r1, u32 r2)
200 {
201         if (is_ereg(r1))
202                 byte |= 1;
203         if (is_ereg(r2))
204                 byte |= 4;
205         return byte;
206 }
207
208 /* Encode 'dst_reg' register into x86-64 opcode 'byte' */
209 static u8 add_1reg(u8 byte, u32 dst_reg)
210 {
211         return byte + reg2hex[dst_reg];
212 }
213
214 /* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
215 static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
216 {
217         return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
218 }
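/*
 * Worked example: for dst_reg = BPF_REG_8 (R14) and src_reg = BPF_REG_1 (RDI),
 * add_2mod(0x48, dst, src) yields the REX prefix 0x49 (REX.W + REX.B) and
 * add_2reg(0xC0, dst, src) yields the ModRM byte 0xFE, so EMIT_mov() emits
 * 49 89 fe ("mov r14, rdi").
 */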
219
220 /* Some 1-byte opcodes for binary ALU operations */
221 static u8 simple_alu_opcodes[] = {
222         [BPF_ADD] = 0x01,
223         [BPF_SUB] = 0x29,
224         [BPF_AND] = 0x21,
225         [BPF_OR] = 0x09,
226         [BPF_XOR] = 0x31,
227         [BPF_LSH] = 0xE0,
228         [BPF_RSH] = 0xE8,
229         [BPF_ARSH] = 0xF8,
230 };
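/*
 * The first five entries are the x86 "op r/m, r" opcode bytes (ADD 0x01,
 * SUB 0x29, AND 0x21, OR 0x09, XOR 0x31). The shift entries are ModRM bytes
 * selecting /4 (shl), /5 (shr) and /7 (sar), used together with the
 * 0xC1/0xD1/0xD3 shift opcodes in do_jit() below.
 */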
231
232 static void jit_fill_hole(void *area, unsigned int size)
233 {
234         /* Fill whole space with INT3 instructions */
235         memset(area, 0xcc, size);
236 }
237
238 int bpf_arch_text_invalidate(void *dst, size_t len)
239 {
240         return IS_ERR_OR_NULL(text_poke_set(dst, 0xcc, len));
241 }
242
243 struct jit_context {
244         int cleanup_addr; /* Epilogue code offset */
245
246         /*
247          * Program specific offsets of labels in the code; these rely on the
248          * JIT doing at least 2 passes, recording the position on the first
249          * pass, only to generate the correct offset on the second pass.
250          */
251         int tail_call_direct_label;
252         int tail_call_indirect_label;
253 };
254
255 /* Maximum number of bytes emitted while JITing one eBPF insn */
256 #define BPF_MAX_INSN_SIZE       128
257 #define BPF_INSN_SAFETY         64
258
259 /* Number of bytes emit_patch() needs to generate instructions */
260 #define X86_PATCH_SIZE          5
261 /* Number of bytes that will be skipped on tailcall */
262 #define X86_TAIL_CALL_OFFSET    (11 + ENDBR_INSN_SIZE)
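/*
 * The 11 bytes are the 5-byte patchable nop, the 2-byte tail_call_cnt init
 * (or nop2), "push rbp" (1 byte) and "mov rbp, rsp" (3 bytes) emitted by
 * emit_prologue(); ENDBR_INSN_SIZE covers the ENDBR emitted ahead of the
 * patchable nop on IBT builds, so a tail call lands on the second ENDBR at
 * the "X86_TAIL_CALL_OFFSET is here" mark below.
 */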
263
264 static void push_r12(u8 **pprog)
265 {
266         u8 *prog = *pprog;
267
268         EMIT2(0x41, 0x54);   /* push r12 */
269         *pprog = prog;
270 }
271
272 static void push_callee_regs(u8 **pprog, bool *callee_regs_used)
273 {
274         u8 *prog = *pprog;
275
276         if (callee_regs_used[0])
277                 EMIT1(0x53);         /* push rbx */
278         if (callee_regs_used[1])
279                 EMIT2(0x41, 0x55);   /* push r13 */
280         if (callee_regs_used[2])
281                 EMIT2(0x41, 0x56);   /* push r14 */
282         if (callee_regs_used[3])
283                 EMIT2(0x41, 0x57);   /* push r15 */
284         *pprog = prog;
285 }
286
287 static void pop_r12(u8 **pprog)
288 {
289         u8 *prog = *pprog;
290
291         EMIT2(0x41, 0x5C);   /* pop r12 */
292         *pprog = prog;
293 }
294
295 static void pop_callee_regs(u8 **pprog, bool *callee_regs_used)
296 {
297         u8 *prog = *pprog;
298
299         if (callee_regs_used[3])
300                 EMIT2(0x41, 0x5F);   /* pop r15 */
301         if (callee_regs_used[2])
302                 EMIT2(0x41, 0x5E);   /* pop r14 */
303         if (callee_regs_used[1])
304                 EMIT2(0x41, 0x5D);   /* pop r13 */
305         if (callee_regs_used[0])
306                 EMIT1(0x5B);         /* pop rbx */
307         *pprog = prog;
308 }
309
310 /*
311  * Emit the various CFI preambles, see asm/cfi.h and the comments about FineIBT
312  * in arch/x86/kernel/alternative.c
313  */
314
315 static void emit_fineibt(u8 **pprog, u32 hash)
316 {
317         u8 *prog = *pprog;
318
319         EMIT_ENDBR();
320         EMIT3_off32(0x41, 0x81, 0xea, hash);            /* subl $hash, %r10d    */
321         EMIT2(0x74, 0x07);                              /* jz.d8 +7             */
322         EMIT2(0x0f, 0x0b);                              /* ud2                  */
323         EMIT1(0x90);                                    /* nop                  */
324         EMIT_ENDBR_POISON();
325
326         *pprog = prog;
327 }
328
329 static void emit_kcfi(u8 **pprog, u32 hash)
330 {
331         u8 *prog = *pprog;
332
333         EMIT1_off32(0xb8, hash);                        /* movl $hash, %eax     */
334 #ifdef CONFIG_CALL_PADDING
335         EMIT1(0x90);
336         EMIT1(0x90);
337         EMIT1(0x90);
338         EMIT1(0x90);
339         EMIT1(0x90);
340         EMIT1(0x90);
341         EMIT1(0x90);
342         EMIT1(0x90);
343         EMIT1(0x90);
344         EMIT1(0x90);
345         EMIT1(0x90);
346 #endif
347         EMIT_ENDBR();
348
349         *pprog = prog;
350 }
351
352 static void emit_cfi(u8 **pprog, u32 hash)
353 {
354         u8 *prog = *pprog;
355
356         switch (cfi_mode) {
357         case CFI_FINEIBT:
358                 emit_fineibt(&prog, hash);
359                 break;
360
361         case CFI_KCFI:
362                 emit_kcfi(&prog, hash);
363                 break;
364
365         default:
366                 EMIT_ENDBR();
367                 break;
368         }
369
370         *pprog = prog;
371 }
372
373 /*
374  * Emit x86-64 prologue code for BPF program.
375  * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
376  * while jumping to another program
377  */
378 static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
379                           bool tail_call_reachable, bool is_subprog,
380                           bool is_exception_cb)
381 {
382         u8 *prog = *pprog;
383
384         emit_cfi(&prog, is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash);
385         /* BPF trampoline can be made to work without these nops,
386          * but let's waste 5 bytes for now and optimize later
387          */
388         memcpy(prog, x86_nops[5], X86_PATCH_SIZE);
389         prog += X86_PATCH_SIZE;
390         if (!ebpf_from_cbpf) {
391                 if (tail_call_reachable && !is_subprog)
392                         /* When it's the entry of the whole tailcall context,
393                          * zeroing rax means initialising tail_call_cnt.
394                          */
395                         EMIT2(0x31, 0xC0); /* xor eax, eax */
396                 else
397                         /* Keep the same instruction layout. */
398                         EMIT2(0x66, 0x90); /* nop2 */
399         }
400         /* Exception callback receives FP as third parameter */
401         if (is_exception_cb) {
402                 EMIT3(0x48, 0x89, 0xF4); /* mov rsp, rsi */
403                 EMIT3(0x48, 0x89, 0xD5); /* mov rbp, rdx */
404                 /* The main frame must have exception_boundary as true, so we
405                  * first restore those callee-saved regs from stack, before
406                  * reusing the stack frame.
407                  */
408                 pop_callee_regs(&prog, all_callee_regs_used);
409                 pop_r12(&prog);
410                 /* Reset the stack frame. */
411                 EMIT3(0x48, 0x89, 0xEC); /* mov rsp, rbp */
412         } else {
413                 EMIT1(0x55);             /* push rbp */
414                 EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
415         }
416
417         /* X86_TAIL_CALL_OFFSET is here */
418         EMIT_ENDBR();
419
420         /* sub rsp, rounded_stack_depth */
421         if (stack_depth)
422                 EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
423         if (tail_call_reachable)
424                 EMIT1(0x50);         /* push rax */
425         *pprog = prog;
426 }
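/*
 * For a plain main program (no IBT, no CFI) with a 16-byte stack and a
 * reachable tail call, emit_prologue() above produces roughly:
 *
 *   nop5                ; patchable, used to attach a BPF trampoline
 *   xor eax, eax        ; tail_call_cnt = 0
 *   push rbp
 *   mov rbp, rsp
 *   sub rsp, 16
 *   push rax            ; spill tail_call_cnt
 */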
427
428 static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode)
429 {
430         u8 *prog = *pprog;
431         s64 offset;
432
433         offset = func - (ip + X86_PATCH_SIZE);
434         if (!is_simm32(offset)) {
435                 pr_err("Target call %p is out of range\n", func);
436                 return -ERANGE;
437         }
438         EMIT1_off32(opcode, offset);
439         *pprog = prog;
440         return 0;
441 }
442
443 static int emit_call(u8 **pprog, void *func, void *ip)
444 {
445         return emit_patch(pprog, func, ip, 0xE8);
446 }
447
448 static int emit_rsb_call(u8 **pprog, void *func, void *ip)
449 {
450         OPTIMIZER_HIDE_VAR(func);
451         x86_call_depth_emit_accounting(pprog, func);
452         return emit_patch(pprog, func, ip, 0xE8);
453 }
454
455 static int emit_jump(u8 **pprog, void *func, void *ip)
456 {
457         return emit_patch(pprog, func, ip, 0xE9);
458 }
459
460 static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
461                                 void *old_addr, void *new_addr)
462 {
463         const u8 *nop_insn = x86_nops[5];
464         u8 old_insn[X86_PATCH_SIZE];
465         u8 new_insn[X86_PATCH_SIZE];
466         u8 *prog;
467         int ret;
468
469         memcpy(old_insn, nop_insn, X86_PATCH_SIZE);
470         if (old_addr) {
471                 prog = old_insn;
472                 ret = t == BPF_MOD_CALL ?
473                       emit_call(&prog, old_addr, ip) :
474                       emit_jump(&prog, old_addr, ip);
475                 if (ret)
476                         return ret;
477         }
478
479         memcpy(new_insn, nop_insn, X86_PATCH_SIZE);
480         if (new_addr) {
481                 prog = new_insn;
482                 ret = t == BPF_MOD_CALL ?
483                       emit_call(&prog, new_addr, ip) :
484                       emit_jump(&prog, new_addr, ip);
485                 if (ret)
486                         return ret;
487         }
488
489         ret = -EBUSY;
490         mutex_lock(&text_mutex);
491         if (memcmp(ip, old_insn, X86_PATCH_SIZE))
492                 goto out;
493         ret = 1;
494         if (memcmp(ip, new_insn, X86_PATCH_SIZE)) {
495                 text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
496                 ret = 0;
497         }
498 out:
499         mutex_unlock(&text_mutex);
500         return ret;
501 }
502
503 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
504                        void *old_addr, void *new_addr)
505 {
506         if (!is_kernel_text((long)ip) &&
507             !is_bpf_text_address((long)ip))
508                 /* BPF poking in modules is not supported */
509                 return -EINVAL;
510
511         /*
512          * See emit_prologue(), for IBT builds the trampoline hook is preceded
513          * with an ENDBR instruction.
514          */
515         if (is_endbr(*(u32 *)ip))
516                 ip += ENDBR_INSN_SIZE;
517
518         return __bpf_arch_text_poke(ip, t, old_addr, new_addr);
519 }
520
521 #define EMIT_LFENCE()   EMIT3(0x0F, 0xAE, 0xE8)
522
523 static void emit_indirect_jump(u8 **pprog, int reg, u8 *ip)
524 {
525         u8 *prog = *pprog;
526
527         if (cpu_feature_enabled(X86_FEATURE_RETPOLINE_LFENCE)) {
528                 EMIT_LFENCE();
529                 EMIT2(0xFF, 0xE0 + reg);
530         } else if (cpu_feature_enabled(X86_FEATURE_RETPOLINE)) {
531                 OPTIMIZER_HIDE_VAR(reg);
532                 if (cpu_feature_enabled(X86_FEATURE_CALL_DEPTH))
533                         emit_jump(&prog, &__x86_indirect_jump_thunk_array[reg], ip);
534                 else
535                         emit_jump(&prog, &__x86_indirect_thunk_array[reg], ip);
536         } else {
537                 EMIT2(0xFF, 0xE0 + reg);        /* jmp *%\reg */
538                 if (IS_ENABLED(CONFIG_RETPOLINE) || IS_ENABLED(CONFIG_SLS))
539                         EMIT1(0xCC);            /* int3 */
540         }
541
542         *pprog = prog;
543 }
544
545 static void emit_return(u8 **pprog, u8 *ip)
546 {
547         u8 *prog = *pprog;
548
549         if (cpu_feature_enabled(X86_FEATURE_RETHUNK)) {
550                 emit_jump(&prog, x86_return_thunk, ip);
551         } else {
552                 EMIT1(0xC3);            /* ret */
553                 if (IS_ENABLED(CONFIG_SLS))
554                         EMIT1(0xCC);    /* int3 */
555         }
556
557         *pprog = prog;
558 }
559
560 /*
561  * Generate the following code:
562  *
563  * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
564  *   if (index >= array->map.max_entries)
565  *     goto out;
566  *   if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
567  *     goto out;
568  *   prog = array->ptrs[index];
569  *   if (prog == NULL)
570  *     goto out;
571  *   goto *(prog->bpf_func + prologue_size);
572  * out:
573  */
574 static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
575                                         u8 **pprog, bool *callee_regs_used,
576                                         u32 stack_depth, u8 *ip,
577                                         struct jit_context *ctx)
578 {
579         int tcc_off = -4 - round_up(stack_depth, 8);
580         u8 *prog = *pprog, *start = *pprog;
581         int offset;
582
583         /*
584          * rdi - pointer to ctx
585          * rsi - pointer to bpf_array
586          * rdx - index in bpf_array
587          */
588
589         /*
590          * if (index >= array->map.max_entries)
591          *      goto out;
592          */
593         EMIT2(0x89, 0xD2);                        /* mov edx, edx */
594         EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
595               offsetof(struct bpf_array, map.max_entries));
596
597         offset = ctx->tail_call_indirect_label - (prog + 2 - start);
598         EMIT2(X86_JBE, offset);                   /* jbe out */
599
600         /*
601          * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
602          *      goto out;
603          */
604         EMIT2_off32(0x8B, 0x85, tcc_off);         /* mov eax, dword ptr [rbp - tcc_off] */
605         EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
606
607         offset = ctx->tail_call_indirect_label - (prog + 2 - start);
608         EMIT2(X86_JAE, offset);                   /* jae out */
609         EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
610         EMIT2_off32(0x89, 0x85, tcc_off);         /* mov dword ptr [rbp - tcc_off], eax */
611
612         /* prog = array->ptrs[index]; */
613         EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6,       /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
614                     offsetof(struct bpf_array, ptrs));
615
616         /*
617          * if (prog == NULL)
618          *      goto out;
619          */
620         EMIT3(0x48, 0x85, 0xC9);                  /* test rcx,rcx */
621
622         offset = ctx->tail_call_indirect_label - (prog + 2 - start);
623         EMIT2(X86_JE, offset);                    /* je out */
624
625         if (bpf_prog->aux->exception_boundary) {
626                 pop_callee_regs(&prog, all_callee_regs_used);
627                 pop_r12(&prog);
628         } else {
629                 pop_callee_regs(&prog, callee_regs_used);
630         }
631
632         EMIT1(0x58);                              /* pop rax */
633         if (stack_depth)
634                 EMIT3_off32(0x48, 0x81, 0xC4,     /* add rsp, sd */
635                             round_up(stack_depth, 8));
636
637         /* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */
638         EMIT4(0x48, 0x8B, 0x49,                   /* mov rcx, qword ptr [rcx + 32] */
639               offsetof(struct bpf_prog, bpf_func));
640         EMIT4(0x48, 0x83, 0xC1,                   /* add rcx, X86_TAIL_CALL_OFFSET */
641               X86_TAIL_CALL_OFFSET);
642         /*
643          * Now we're ready to jump into next BPF program
644          * rdi == ctx (1st arg)
645          * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET
646          */
647         emit_indirect_jump(&prog, 1 /* rcx */, ip + (prog - start));
648
649         /* out: */
650         ctx->tail_call_indirect_label = prog - start;
651         *pprog = prog;
652 }
653
654 static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog,
655                                       struct bpf_jit_poke_descriptor *poke,
656                                       u8 **pprog, u8 *ip,
657                                       bool *callee_regs_used, u32 stack_depth,
658                                       struct jit_context *ctx)
659 {
660         int tcc_off = -4 - round_up(stack_depth, 8);
661         u8 *prog = *pprog, *start = *pprog;
662         int offset;
663
664         /*
665          * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
666          *      goto out;
667          */
668         EMIT2_off32(0x8B, 0x85, tcc_off);             /* mov eax, dword ptr [rbp - tcc_off] */
669         EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);         /* cmp eax, MAX_TAIL_CALL_CNT */
670
671         offset = ctx->tail_call_direct_label - (prog + 2 - start);
672         EMIT2(X86_JAE, offset);                       /* jae out */
673         EMIT3(0x83, 0xC0, 0x01);                      /* add eax, 1 */
674         EMIT2_off32(0x89, 0x85, tcc_off);             /* mov dword ptr [rbp - tcc_off], eax */
675
676         poke->tailcall_bypass = ip + (prog - start);
677         poke->adj_off = X86_TAIL_CALL_OFFSET;
678         poke->tailcall_target = ip + ctx->tail_call_direct_label - X86_PATCH_SIZE;
679         poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
680
681         emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
682                   poke->tailcall_bypass);
683
684         if (bpf_prog->aux->exception_boundary) {
685                 pop_callee_regs(&prog, all_callee_regs_used);
686                 pop_r12(&prog);
687         } else {
688                 pop_callee_regs(&prog, callee_regs_used);
689         }
690
691         EMIT1(0x58);                                  /* pop rax */
692         if (stack_depth)
693                 EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
694
695         memcpy(prog, x86_nops[5], X86_PATCH_SIZE);
696         prog += X86_PATCH_SIZE;
697
698         /* out: */
699         ctx->tail_call_direct_label = prog - start;
700
701         *pprog = prog;
702 }
703
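/*
 * For direct tail calls the JIT initially emits a jump at tailcall_bypass
 * that skips the register/stack restore and the 5-byte nop placed at
 * tailcall_target. Once the target program is known,
 * bpf_tail_call_direct_fixup() patches that nop into a direct jump to
 * target->bpf_func + X86_TAIL_CALL_OFFSET and turns the bypass jump back
 * into a nop.
 */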
704 static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
705 {
706         struct bpf_jit_poke_descriptor *poke;
707         struct bpf_array *array;
708         struct bpf_prog *target;
709         int i, ret;
710
711         for (i = 0; i < prog->aux->size_poke_tab; i++) {
712                 poke = &prog->aux->poke_tab[i];
713                 if (poke->aux && poke->aux != prog->aux)
714                         continue;
715
716                 WARN_ON_ONCE(READ_ONCE(poke->tailcall_target_stable));
717
718                 if (poke->reason != BPF_POKE_REASON_TAIL_CALL)
719                         continue;
720
721                 array = container_of(poke->tail_call.map, struct bpf_array, map);
722                 mutex_lock(&array->aux->poke_mutex);
723                 target = array->ptrs[poke->tail_call.key];
724                 if (target) {
725                         ret = __bpf_arch_text_poke(poke->tailcall_target,
726                                                    BPF_MOD_JUMP, NULL,
727                                                    (u8 *)target->bpf_func +
728                                                    poke->adj_off);
729                         BUG_ON(ret < 0);
730                         ret = __bpf_arch_text_poke(poke->tailcall_bypass,
731                                                    BPF_MOD_JUMP,
732                                                    (u8 *)poke->tailcall_target +
733                                                    X86_PATCH_SIZE, NULL);
734                         BUG_ON(ret < 0);
735                 }
736                 WRITE_ONCE(poke->tailcall_target_stable, true);
737                 mutex_unlock(&array->aux->poke_mutex);
738         }
739 }
740
741 static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
742                            u32 dst_reg, const u32 imm32)
743 {
744         u8 *prog = *pprog;
745         u8 b1, b2, b3;
746
747         /*
748          * Optimization: if imm32 is positive, use 'mov %eax, imm32'
749          * (which zero-extends imm32) to save 2 bytes.
750          */
751         if (sign_propagate && (s32)imm32 < 0) {
752                 /* 'mov %rax, imm32' sign extends imm32 */
753                 b1 = add_1mod(0x48, dst_reg);
754                 b2 = 0xC7;
755                 b3 = 0xC0;
756                 EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
757                 goto done;
758         }
759
760         /*
761          * Optimization: if imm32 is zero, use 'xor %eax, %eax'
762          * to save 3 bytes.
763          */
764         if (imm32 == 0) {
765                 if (is_ereg(dst_reg))
766                         EMIT1(add_2mod(0x40, dst_reg, dst_reg));
767                 b2 = 0x31; /* xor */
768                 b3 = 0xC0;
769                 EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
770                 goto done;
771         }
772
773         /* mov %eax, imm32 */
774         if (is_ereg(dst_reg))
775                 EMIT1(add_1mod(0x40, dst_reg));
776         EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
777 done:
778         *pprog = prog;
779 }
780
781 static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
782                            const u32 imm32_hi, const u32 imm32_lo)
783 {
784         u8 *prog = *pprog;
785
786         if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
787                 /*
788                  * For emitting a plain u32, where the sign bit must not be
789                  * propagated, LLVM tends to emit an imm64 load rather than a
790                  * mov32, so save a couple of bytes by just doing
791                  * 'mov %eax, imm32' instead.
792                  */
793                 emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
794         } else {
795                 /* movabsq rax, imm64 */
796                 EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
797                 EMIT(imm32_lo, 4);
798                 EMIT(imm32_hi, 4);
799         }
800
801         *pprog = prog;
802 }
803
804 static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
805 {
806         u8 *prog = *pprog;
807
808         if (is64) {
809                 /* mov dst, src */
810                 EMIT_mov(dst_reg, src_reg);
811         } else {
812                 /* mov32 dst, src */
813                 if (is_ereg(dst_reg) || is_ereg(src_reg))
814                         EMIT1(add_2mod(0x40, dst_reg, src_reg));
815                 EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
816         }
817
818         *pprog = prog;
819 }
820
821 static void emit_movsx_reg(u8 **pprog, int num_bits, bool is64, u32 dst_reg,
822                            u32 src_reg)
823 {
824         u8 *prog = *pprog;
825
826         if (is64) {
827                 /* movs[b,w,l]q dst, src */
828                 if (num_bits == 8)
829                         EMIT4(add_2mod(0x48, src_reg, dst_reg), 0x0f, 0xbe,
830                               add_2reg(0xC0, src_reg, dst_reg));
831                 else if (num_bits == 16)
832                         EMIT4(add_2mod(0x48, src_reg, dst_reg), 0x0f, 0xbf,
833                               add_2reg(0xC0, src_reg, dst_reg));
834                 else if (num_bits == 32)
835                         EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x63,
836                               add_2reg(0xC0, src_reg, dst_reg));
837         } else {
838                 /* movs[b,w]l dst, src */
839                 if (num_bits == 8) {
840                         EMIT4(add_2mod(0x40, src_reg, dst_reg), 0x0f, 0xbe,
841                               add_2reg(0xC0, src_reg, dst_reg));
842                 } else if (num_bits == 16) {
843                         if (is_ereg(dst_reg) || is_ereg(src_reg))
844                                 EMIT1(add_2mod(0x40, src_reg, dst_reg));
845                         EMIT3(add_2mod(0x0f, src_reg, dst_reg), 0xbf,
846                               add_2reg(0xC0, src_reg, dst_reg));
847                 }
848         }
849
850         *pprog = prog;
851 }
852
853 /* Emit the suffix (ModR/M etc) for addressing *(ptr_reg + off) and val_reg */
854 static void emit_insn_suffix(u8 **pprog, u32 ptr_reg, u32 val_reg, int off)
855 {
856         u8 *prog = *pprog;
857
858         if (is_imm8(off)) {
859                 /* 1-byte signed displacement.
860                  *
861                  * If off == 0 we could skip this and save one extra byte, but
862                  * the special case of x86 R13, which always needs an offset,
863                  * is not worth the hassle
864                  */
865                 EMIT2(add_2reg(0x40, ptr_reg, val_reg), off);
866         } else {
867                 /* 4-byte signed displacement */
868                 EMIT1_off32(add_2reg(0x80, ptr_reg, val_reg), off);
869         }
870         *pprog = prog;
871 }
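/*
 * Example: with ptr_reg = BPF_REG_1 (RDI), val_reg = BPF_REG_0 (RAX) and
 * off = 8, this emits the ModRM/disp8 pair 47 08, which together with a
 * preceding 0x8B opcode forms "mov eax, dword ptr [rdi + 8]".
 */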
872
873 /*
874  * Emit a REX byte if it will be necessary to address these registers
875  */
876 static void maybe_emit_mod(u8 **pprog, u32 dst_reg, u32 src_reg, bool is64)
877 {
878         u8 *prog = *pprog;
879
880         if (is64)
881                 EMIT1(add_2mod(0x48, dst_reg, src_reg));
882         else if (is_ereg(dst_reg) || is_ereg(src_reg))
883                 EMIT1(add_2mod(0x40, dst_reg, src_reg));
884         *pprog = prog;
885 }
886
887 /*
888  * Similar version of maybe_emit_mod() for a single register
889  */
890 static void maybe_emit_1mod(u8 **pprog, u32 reg, bool is64)
891 {
892         u8 *prog = *pprog;
893
894         if (is64)
895                 EMIT1(add_1mod(0x48, reg));
896         else if (is_ereg(reg))
897                 EMIT1(add_1mod(0x40, reg));
898         *pprog = prog;
899 }
900
901 /* LDX: dst_reg = *(u8*)(src_reg + off) */
902 static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
903 {
904         u8 *prog = *pprog;
905
906         switch (size) {
907         case BPF_B:
908                 /* Emit 'movzx rax, byte ptr [rax + off]' */
909                 EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
910                 break;
911         case BPF_H:
912                 /* Emit 'movzx rax, word ptr [rax + off]' */
913                 EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
914                 break;
915         case BPF_W:
916                 /* Emit 'mov eax, dword ptr [rax+0x14]' */
917                 if (is_ereg(dst_reg) || is_ereg(src_reg))
918                         EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
919                 else
920                         EMIT1(0x8B);
921                 break;
922         case BPF_DW:
923                 /* Emit 'mov rax, qword ptr [rax+0x14]' */
924                 EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
925                 break;
926         }
927         emit_insn_suffix(&prog, src_reg, dst_reg, off);
928         *pprog = prog;
929 }
930
931 /* LDSX: dst_reg = *(s8*)(src_reg + off) */
932 static void emit_ldsx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
933 {
934         u8 *prog = *pprog;
935
936         switch (size) {
937         case BPF_B:
938                 /* Emit 'movsx rax, byte ptr [rax + off]' */
939                 EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBE);
940                 break;
941         case BPF_H:
942                 /* Emit 'movsx rax, word ptr [rax + off]' */
943                 EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBF);
944                 break;
945         case BPF_W:
946                 /* Emit 'movsx rax, dword ptr [rax+0x14]' */
947                 EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x63);
948                 break;
949         }
950         emit_insn_suffix(&prog, src_reg, dst_reg, off);
951         *pprog = prog;
952 }
953
954 /* STX: *(u8*)(dst_reg + off) = src_reg */
955 static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
956 {
957         u8 *prog = *pprog;
958
959         switch (size) {
960         case BPF_B:
961                 /* Emit 'mov byte ptr [rax + off], al' */
962                 if (is_ereg(dst_reg) || is_ereg_8l(src_reg))
963                         /* Add extra byte for eregs or SIL,DIL,BPL in src_reg */
964                         EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
965                 else
966                         EMIT1(0x88);
967                 break;
968         case BPF_H:
969                 if (is_ereg(dst_reg) || is_ereg(src_reg))
970                         EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
971                 else
972                         EMIT2(0x66, 0x89);
973                 break;
974         case BPF_W:
975                 if (is_ereg(dst_reg) || is_ereg(src_reg))
976                         EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
977                 else
978                         EMIT1(0x89);
979                 break;
980         case BPF_DW:
981                 EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
982                 break;
983         }
984         emit_insn_suffix(&prog, dst_reg, src_reg, off);
985         *pprog = prog;
986 }
987
988 static int emit_atomic(u8 **pprog, u8 atomic_op,
989                        u32 dst_reg, u32 src_reg, s16 off, u8 bpf_size)
990 {
991         u8 *prog = *pprog;
992
993         EMIT1(0xF0); /* lock prefix */
994
995         maybe_emit_mod(&prog, dst_reg, src_reg, bpf_size == BPF_DW);
996
997         /* emit opcode */
998         switch (atomic_op) {
999         case BPF_ADD:
1000         case BPF_AND:
1001         case BPF_OR:
1002         case BPF_XOR:
1003                 /* lock *(u32/u64*)(dst_reg + off) <op>= src_reg */
1004                 EMIT1(simple_alu_opcodes[atomic_op]);
1005                 break;
1006         case BPF_ADD | BPF_FETCH:
1007                 /* src_reg = atomic_fetch_add(dst_reg + off, src_reg); */
1008                 EMIT2(0x0F, 0xC1);
1009                 break;
1010         case BPF_XCHG:
1011                 /* src_reg = atomic_xchg(dst_reg + off, src_reg); */
1012                 EMIT1(0x87);
1013                 break;
1014         case BPF_CMPXCHG:
1015                 /* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
1016                 EMIT2(0x0F, 0xB1);
1017                 break;
1018         default:
1019                 pr_err("bpf_jit: unknown atomic opcode %02x\n", atomic_op);
1020                 return -EFAULT;
1021         }
1022
1023         emit_insn_suffix(&prog, dst_reg, src_reg, off);
1024
1025         *pprog = prog;
1026         return 0;
1027 }
1028
1029 bool ex_handler_bpf(const struct exception_table_entry *x, struct pt_regs *regs)
1030 {
1031         u32 reg = x->fixup >> 8;
1032
1033         /* jump over faulting load and clear dest register */
1034         *(unsigned long *)((void *)regs + reg) = 0;
1035         regs->ip += x->fixup & 0xff;
1036         return true;
1037 }
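/*
 * The low byte of ->fixup holds the length of the faulting load, so that
 * regs->ip can be advanced past it; the upper bits hold the pt_regs byte
 * offset of the destination register (taken from reg2pt_regs[] above),
 * which is zeroed before resuming.
 */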
1038
1039 static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
1040                              bool *regs_used, bool *tail_call_seen)
1041 {
1042         int i;
1043
1044         for (i = 1; i <= insn_cnt; i++, insn++) {
1045                 if (insn->code == (BPF_JMP | BPF_TAIL_CALL))
1046                         *tail_call_seen = true;
1047                 if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
1048                         regs_used[0] = true;
1049                 if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
1050                         regs_used[1] = true;
1051                 if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
1052                         regs_used[2] = true;
1053                 if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
1054                         regs_used[3] = true;
1055         }
1056 }
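/*
 * Indices 0..3 of regs_used correspond to BPF_REG_6..BPF_REG_9, i.e. RBX,
 * R13, R14 and R15, matching the order used by push_callee_regs() and
 * pop_callee_regs().
 */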
1057
1058 static void emit_nops(u8 **pprog, int len)
1059 {
1060         u8 *prog = *pprog;
1061         int i, noplen;
1062
1063         while (len > 0) {
1064                 noplen = len;
1065
1066                 if (noplen > ASM_NOP_MAX)
1067                         noplen = ASM_NOP_MAX;
1068
1069                 for (i = 0; i < noplen; i++)
1070                         EMIT1(x86_nops[noplen][i]);
1071                 len -= noplen;
1072         }
1073
1074         *pprog = prog;
1075 }
1076
1077 /* emit the 3-byte VEX prefix
1078  *
1079  * r: same as rex.r, extra bit for ModRM reg field
1080  * x: same as rex.x, extra bit for SIB index field
1081  * b: same as rex.b, extra bit for ModRM r/m, or SIB base
1082  * m: opcode map select, encoding escape bytes e.g. 0x0f38
1083  * w: same as rex.w (32 bit or 64 bit) or opcode specific
1084  * src_reg2: additional source reg (encoded as BPF reg)
1085  * l: vector length (128 bit or 256 bit) or reserved
1086  * pp: opcode prefix (none, 0x66, 0xf2 or 0xf3)
1087  */
1088 static void emit_3vex(u8 **pprog, bool r, bool x, bool b, u8 m,
1089                       bool w, u8 src_reg2, bool l, u8 pp)
1090 {
1091         u8 *prog = *pprog;
1092         const u8 b0 = 0xc4; /* first byte of 3-byte VEX prefix */
1093         u8 b1, b2;
1094         u8 vvvv = reg2hex[src_reg2];
1095
1096         /* reg2hex gives only the lower 3 bit of vvvv */
1097         if (is_ereg(src_reg2))
1098                 vvvv |= 1 << 3;
1099
1100         /*
1101          * 2nd byte of 3-byte VEX prefix
1102          * ~ means bit inverted encoding
1103          *
1104          *    7                           0
1105          *  +---+---+---+---+---+---+---+---+
1106          *  |~R |~X |~B |         m         |
1107          *  +---+---+---+---+---+---+---+---+
1108          */
1109         b1 = (!r << 7) | (!x << 6) | (!b << 5) | (m & 0x1f);
1110         /*
1111          * 3rd byte of 3-byte VEX prefix
1112          *
1113          *    7                           0
1114          *  +---+---+---+---+---+---+---+---+
1115          *  | W |     ~vvvv     | L |   pp  |
1116          *  +---+---+---+---+---+---+---+---+
1117          */
1118         b2 = (w << 7) | ((~vvvv & 0xf) << 3) | (l << 2) | (pp & 3);
1119
1120         EMIT3(b0, b1, b2);
1121         *pprog = prog;
1122 }
1123
1124 /* emit BMI2 shift instruction */
1125 static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op)
1126 {
1127         u8 *prog = *pprog;
1128         bool r = is_ereg(dst_reg);
1129         u8 m = 2; /* escape code 0f38 */
1130
1131         emit_3vex(&prog, r, false, r, m, is64, src_reg, false, op);
1132         EMIT2(0xf7, add_2reg(0xC0, dst_reg, dst_reg));
1133         *pprog = prog;
1134 }
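/*
 * Illustration: a 64-bit BPF_LSH with dst_reg = BPF_REG_0 (RAX) and
 * src_reg = BPF_REG_2 (RSI) emits c4 e2 c9 f7 c0, i.e. "shlx rax, rax, rsi",
 * with the shift count taken from the VEX.vvvv register (RSI here).
 */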
1135
1136 #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
1137
1138 /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
1139 #define RESTORE_TAIL_CALL_CNT(stack)                            \
1140         EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8)
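/* tail_call_cnt was spilled by the "push rax" at the end of emit_prologue(),
 * just below the rounded stack area; this macro reloads it into rax.
 */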
1141
1142 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image,
1143                   int oldproglen, struct jit_context *ctx, bool jmp_padding)
1144 {
1145         bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
1146         struct bpf_insn *insn = bpf_prog->insnsi;
1147         bool callee_regs_used[4] = {};
1148         int insn_cnt = bpf_prog->len;
1149         bool tail_call_seen = false;
1150         bool seen_exit = false;
1151         u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
1152         int i, excnt = 0;
1153         int ilen, proglen = 0;
1154         u8 *prog = temp;
1155         int err;
1156
1157         detect_reg_usage(insn, insn_cnt, callee_regs_used,
1158                          &tail_call_seen);
1159
1160         /* tail call's presence in current prog implies it is reachable */
1161         tail_call_reachable |= tail_call_seen;
1162
1163         emit_prologue(&prog, bpf_prog->aux->stack_depth,
1164                       bpf_prog_was_classic(bpf_prog), tail_call_reachable,
1165                       bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb);
1166         /* Exception callback will clobber callee regs for its own use, and
1167          * restore the original callee regs from main prog's stack frame.
1168          */
1169         if (bpf_prog->aux->exception_boundary) {
1170                 /* We also need to save r12, which is not mapped to any BPF
1171                  * register, as we throw after entry into the kernel, which may
1172                  * overwrite r12.
1173                  */
1174                 push_r12(&prog);
1175                 push_callee_regs(&prog, all_callee_regs_used);
1176         } else {
1177                 push_callee_regs(&prog, callee_regs_used);
1178         }
1179
1180         ilen = prog - temp;
1181         if (rw_image)
1182                 memcpy(rw_image + proglen, temp, ilen);
1183         proglen += ilen;
1184         addrs[0] = proglen;
1185         prog = temp;
1186
1187         for (i = 1; i <= insn_cnt; i++, insn++) {
1188                 const s32 imm32 = insn->imm;
1189                 u32 dst_reg = insn->dst_reg;
1190                 u32 src_reg = insn->src_reg;
1191                 u8 b2 = 0, b3 = 0;
1192                 u8 *start_of_ldx;
1193                 s64 jmp_offset;
1194                 s16 insn_off;
1195                 u8 jmp_cond;
1196                 u8 *func;
1197                 int nops;
1198
1199                 switch (insn->code) {
1200                         /* ALU */
1201                 case BPF_ALU | BPF_ADD | BPF_X:
1202                 case BPF_ALU | BPF_SUB | BPF_X:
1203                 case BPF_ALU | BPF_AND | BPF_X:
1204                 case BPF_ALU | BPF_OR | BPF_X:
1205                 case BPF_ALU | BPF_XOR | BPF_X:
1206                 case BPF_ALU64 | BPF_ADD | BPF_X:
1207                 case BPF_ALU64 | BPF_SUB | BPF_X:
1208                 case BPF_ALU64 | BPF_AND | BPF_X:
1209                 case BPF_ALU64 | BPF_OR | BPF_X:
1210                 case BPF_ALU64 | BPF_XOR | BPF_X:
1211                         maybe_emit_mod(&prog, dst_reg, src_reg,
1212                                        BPF_CLASS(insn->code) == BPF_ALU64);
1213                         b2 = simple_alu_opcodes[BPF_OP(insn->code)];
1214                         EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
1215                         break;
1216
1217                 case BPF_ALU64 | BPF_MOV | BPF_X:
1218                 case BPF_ALU | BPF_MOV | BPF_X:
1219                         if (insn->off == 0)
1220                                 emit_mov_reg(&prog,
1221                                              BPF_CLASS(insn->code) == BPF_ALU64,
1222                                              dst_reg, src_reg);
1223                         else
1224                                 emit_movsx_reg(&prog, insn->off,
1225                                                BPF_CLASS(insn->code) == BPF_ALU64,
1226                                                dst_reg, src_reg);
1227                         break;
1228
1229                         /* neg dst */
1230                 case BPF_ALU | BPF_NEG:
1231                 case BPF_ALU64 | BPF_NEG:
1232                         maybe_emit_1mod(&prog, dst_reg,
1233                                         BPF_CLASS(insn->code) == BPF_ALU64);
1234                         EMIT2(0xF7, add_1reg(0xD8, dst_reg));
1235                         break;
1236
1237                 case BPF_ALU | BPF_ADD | BPF_K:
1238                 case BPF_ALU | BPF_SUB | BPF_K:
1239                 case BPF_ALU | BPF_AND | BPF_K:
1240                 case BPF_ALU | BPF_OR | BPF_K:
1241                 case BPF_ALU | BPF_XOR | BPF_K:
1242                 case BPF_ALU64 | BPF_ADD | BPF_K:
1243                 case BPF_ALU64 | BPF_SUB | BPF_K:
1244                 case BPF_ALU64 | BPF_AND | BPF_K:
1245                 case BPF_ALU64 | BPF_OR | BPF_K:
1246                 case BPF_ALU64 | BPF_XOR | BPF_K:
1247                         maybe_emit_1mod(&prog, dst_reg,
1248                                         BPF_CLASS(insn->code) == BPF_ALU64);
1249
1250                         /*
1251                          * b3 holds 'normal' opcode, b2 short form only valid
1252                          * in case dst is eax/rax.
1253                          */
1254                         switch (BPF_OP(insn->code)) {
1255                         case BPF_ADD:
1256                                 b3 = 0xC0;
1257                                 b2 = 0x05;
1258                                 break;
1259                         case BPF_SUB:
1260                                 b3 = 0xE8;
1261                                 b2 = 0x2D;
1262                                 break;
1263                         case BPF_AND:
1264                                 b3 = 0xE0;
1265                                 b2 = 0x25;
1266                                 break;
1267                         case BPF_OR:
1268                                 b3 = 0xC8;
1269                                 b2 = 0x0D;
1270                                 break;
1271                         case BPF_XOR:
1272                                 b3 = 0xF0;
1273                                 b2 = 0x35;
1274                                 break;
1275                         }
1276
1277                         if (is_imm8(imm32))
1278                                 EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
1279                         else if (is_axreg(dst_reg))
1280                                 EMIT1_off32(b2, imm32);
1281                         else
1282                                 EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
1283                         break;
1284
1285                 case BPF_ALU64 | BPF_MOV | BPF_K:
1286                 case BPF_ALU | BPF_MOV | BPF_K:
1287                         emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
1288                                        dst_reg, imm32);
1289                         break;
1290
1291                 case BPF_LD | BPF_IMM | BPF_DW:
1292                         emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
1293                         insn++;
1294                         i++;
1295                         break;
1296
1297                         /* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
1298                 case BPF_ALU | BPF_MOD | BPF_X:
1299                 case BPF_ALU | BPF_DIV | BPF_X:
1300                 case BPF_ALU | BPF_MOD | BPF_K:
1301                 case BPF_ALU | BPF_DIV | BPF_K:
1302                 case BPF_ALU64 | BPF_MOD | BPF_X:
1303                 case BPF_ALU64 | BPF_DIV | BPF_X:
1304                 case BPF_ALU64 | BPF_MOD | BPF_K:
1305                 case BPF_ALU64 | BPF_DIV | BPF_K: {
1306                         bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;
1307
1308                         if (dst_reg != BPF_REG_0)
1309                                 EMIT1(0x50); /* push rax */
1310                         if (dst_reg != BPF_REG_3)
1311                                 EMIT1(0x52); /* push rdx */
1312
1313                         if (BPF_SRC(insn->code) == BPF_X) {
1314                                 if (src_reg == BPF_REG_0 ||
1315                                     src_reg == BPF_REG_3) {
1316                                         /* mov r11, src_reg */
1317                                         EMIT_mov(AUX_REG, src_reg);
1318                                         src_reg = AUX_REG;
1319                                 }
1320                         } else {
1321                                 /* mov r11, imm32 */
1322                                 EMIT3_off32(0x49, 0xC7, 0xC3, imm32);
1323                                 src_reg = AUX_REG;
1324                         }
1325
1326                         if (dst_reg != BPF_REG_0)
1327                                 /* mov rax, dst_reg */
1328                                 emit_mov_reg(&prog, is64, BPF_REG_0, dst_reg);
1329
1330                         if (insn->off == 0) {
1331                                 /*
1332                                  * xor edx, edx
1333                                  * equivalent to 'xor rdx, rdx', but one byte less
1334                                  */
1335                                 EMIT2(0x31, 0xd2);
1336
1337                                 /* div src_reg */
1338                                 maybe_emit_1mod(&prog, src_reg, is64);
1339                                 EMIT2(0xF7, add_1reg(0xF0, src_reg));
1340                         } else {
1341                                 if (BPF_CLASS(insn->code) == BPF_ALU)
1342                                         EMIT1(0x99); /* cdq */
1343                                 else
1344                                         EMIT2(0x48, 0x99); /* cqo */
1345
1346                                 /* idiv src_reg */
1347                                 maybe_emit_1mod(&prog, src_reg, is64);
1348                                 EMIT2(0xF7, add_1reg(0xF8, src_reg));
1349                         }
1350
1351                         if (BPF_OP(insn->code) == BPF_MOD &&
1352                             dst_reg != BPF_REG_3)
1353                                 /* mov dst_reg, rdx */
1354                                 emit_mov_reg(&prog, is64, dst_reg, BPF_REG_3);
1355                         else if (BPF_OP(insn->code) == BPF_DIV &&
1356                                  dst_reg != BPF_REG_0)
1357                                 /* mov dst_reg, rax */
1358                                 emit_mov_reg(&prog, is64, dst_reg, BPF_REG_0);
1359
1360                         if (dst_reg != BPF_REG_3)
1361                                 EMIT1(0x5A); /* pop rdx */
1362                         if (dst_reg != BPF_REG_0)
1363                                 EMIT1(0x58); /* pop rax */
1364                         break;
1365                 }
1366
1367                 case BPF_ALU | BPF_MUL | BPF_K:
1368                 case BPF_ALU64 | BPF_MUL | BPF_K:
1369                         maybe_emit_mod(&prog, dst_reg, dst_reg,
1370                                        BPF_CLASS(insn->code) == BPF_ALU64);
1371
1372                         if (is_imm8(imm32))
1373                                 /* imul dst_reg, dst_reg, imm8 */
1374                                 EMIT3(0x6B, add_2reg(0xC0, dst_reg, dst_reg),
1375                                       imm32);
1376                         else
1377                                 /* imul dst_reg, dst_reg, imm32 */
1378                                 EMIT2_off32(0x69,
1379                                             add_2reg(0xC0, dst_reg, dst_reg),
1380                                             imm32);
1381                         break;
1382
1383                 case BPF_ALU | BPF_MUL | BPF_X:
1384                 case BPF_ALU64 | BPF_MUL | BPF_X:
1385                         maybe_emit_mod(&prog, src_reg, dst_reg,
1386                                        BPF_CLASS(insn->code) == BPF_ALU64);
1387
1388                         /* imul dst_reg, src_reg */
1389                         EMIT3(0x0F, 0xAF, add_2reg(0xC0, src_reg, dst_reg));
1390                         break;
1391
1392                         /* Shifts */
1393                 case BPF_ALU | BPF_LSH | BPF_K:
1394                 case BPF_ALU | BPF_RSH | BPF_K:
1395                 case BPF_ALU | BPF_ARSH | BPF_K:
1396                 case BPF_ALU64 | BPF_LSH | BPF_K:
1397                 case BPF_ALU64 | BPF_RSH | BPF_K:
1398                 case BPF_ALU64 | BPF_ARSH | BPF_K:
1399                         maybe_emit_1mod(&prog, dst_reg,
1400                                         BPF_CLASS(insn->code) == BPF_ALU64);
1401
1402                         b3 = simple_alu_opcodes[BPF_OP(insn->code)];
1403                         if (imm32 == 1)
1404                                 EMIT2(0xD1, add_1reg(b3, dst_reg));
1405                         else
1406                                 EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
1407                         break;
1408
1409                 case BPF_ALU | BPF_LSH | BPF_X:
1410                 case BPF_ALU | BPF_RSH | BPF_X:
1411                 case BPF_ALU | BPF_ARSH | BPF_X:
1412                 case BPF_ALU64 | BPF_LSH | BPF_X:
1413                 case BPF_ALU64 | BPF_RSH | BPF_X:
1414                 case BPF_ALU64 | BPF_ARSH | BPF_X:
1415                         /* BMI2 shifts aren't better when shift count is already in rcx */
1416                         if (boot_cpu_has(X86_FEATURE_BMI2) && src_reg != BPF_REG_4) {
1417                                 /* shrx/sarx/shlx dst_reg, dst_reg, src_reg */
1418                                 bool w = (BPF_CLASS(insn->code) == BPF_ALU64);
1419                                 u8 op;
1420
1421                                 switch (BPF_OP(insn->code)) {
1422                                 case BPF_LSH:
1423                                         op = 1; /* prefix 0x66 */
1424                                         break;
1425                                 case BPF_RSH:
1426                                         op = 3; /* prefix 0xf2 */
1427                                         break;
1428                                 case BPF_ARSH:
1429                                         op = 2; /* prefix 0xf3 */
1430                                         break;
1431                                 }
1432
1433                                 emit_shiftx(&prog, dst_reg, src_reg, w, op);
1434
1435                                 break;
1436                         }
1437
1438                         if (src_reg != BPF_REG_4) { /* common case */
1439                                 /* Check for bad case when dst_reg == rcx */
1440                                 if (dst_reg == BPF_REG_4) {
1441                                         /* mov r11, dst_reg */
1442                                         EMIT_mov(AUX_REG, dst_reg);
1443                                         dst_reg = AUX_REG;
1444                                 } else {
1445                                         EMIT1(0x51); /* push rcx */
1446                                 }
1447                                 /* mov rcx, src_reg */
1448                                 EMIT_mov(BPF_REG_4, src_reg);
1449                         }
1450
1451                         /* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
1452                         maybe_emit_1mod(&prog, dst_reg,
1453                                         BPF_CLASS(insn->code) == BPF_ALU64);
1454
1455                         b3 = simple_alu_opcodes[BPF_OP(insn->code)];
1456                         EMIT2(0xD3, add_1reg(b3, dst_reg));
1457
1458                         if (src_reg != BPF_REG_4) {
1459                                 if (insn->dst_reg == BPF_REG_4)
1460                                         /* mov dst_reg, r11 */
1461                                         EMIT_mov(insn->dst_reg, AUX_REG);
1462                                 else
1463                                         EMIT1(0x59); /* pop rcx */
1464                         }
1465
1466                         break;
1467
1468                 case BPF_ALU | BPF_END | BPF_FROM_BE:
1469                 case BPF_ALU64 | BPF_END | BPF_FROM_LE:
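                        /*
                         * x86 is little-endian, so BPF_FROM_BE (and the ALU64
                         * unconditional byte swap) needs a real swap: ror+movzx
                         * for 16 bit, bswap r32 for 32 bit (which also zeroes
                         * the upper half of the register) and REX.W bswap for
                         * 64 bit.
                         */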
1470                         switch (imm32) {
1471                         case 16:
1472                                 /* Emit 'ror %ax, 8' to swap lower 2 bytes */
1473                                 EMIT1(0x66);
1474                                 if (is_ereg(dst_reg))
1475                                         EMIT1(0x41);
1476                                 EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);
1477
1478                                 /* Emit 'movzwl eax, ax' */
1479                                 if (is_ereg(dst_reg))
1480                                         EMIT3(0x45, 0x0F, 0xB7);
1481                                 else
1482                                         EMIT2(0x0F, 0xB7);
1483                                 EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
1484                                 break;
1485                         case 32:
1486                                 /* Emit 'bswap eax' to swap lower 4 bytes */
1487                                 if (is_ereg(dst_reg))
1488                                         EMIT2(0x41, 0x0F);
1489                                 else
1490                                         EMIT1(0x0F);
1491                                 EMIT1(add_1reg(0xC8, dst_reg));
1492                                 break;
1493                         case 64:
1494                                 /* Emit 'bswap rax' to swap 8 bytes */
1495                                 EMIT3(add_1mod(0x48, dst_reg), 0x0F,
1496                                       add_1reg(0xC8, dst_reg));
1497                                 break;
1498                         }
1499                         break;
1500
1501                 case BPF_ALU | BPF_END | BPF_FROM_LE:
1502                         switch (imm32) {
1503                         case 16:
1504                                 /*
1505                                  * Emit 'movzwl eax, ax' to zero extend 16-bit
1506                                  * into 64 bit
1507                                  */
1508                                 if (is_ereg(dst_reg))
1509                                         EMIT3(0x45, 0x0F, 0xB7);
1510                                 else
1511                                         EMIT2(0x0F, 0xB7);
1512                                 EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
1513                                 break;
1514                         case 32:
1515                                 /* Emit 'mov eax, eax' to clear upper 32-bits */
1516                                 if (is_ereg(dst_reg))
1517                                         EMIT1(0x45);
1518                                 EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
1519                                 break;
1520                         case 64:
1521                                 /* nop */
1522                                 break;
1523                         }
1524                         break;
1525
1526                         /* speculation barrier */
1527                 case BPF_ST | BPF_NOSPEC:
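                        /*
                         * EMIT_LFENCE() emits an lfence-style barrier; the
                         * verifier inserts BPF_NOSPEC to guard against
                         * speculative store bypass (Spectre v4) gadgets.
                         */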
1528                         EMIT_LFENCE();
1529                         break;
1530
1531                         /* ST: *(u8*)(dst_reg + off) = imm */
1532                 case BPF_ST | BPF_MEM | BPF_B:
1533                         if (is_ereg(dst_reg))
1534                                 EMIT2(0x41, 0xC6);
1535                         else
1536                                 EMIT1(0xC6);
1537                         goto st;
1538                 case BPF_ST | BPF_MEM | BPF_H:
1539                         if (is_ereg(dst_reg))
1540                                 EMIT3(0x66, 0x41, 0xC7);
1541                         else
1542                                 EMIT2(0x66, 0xC7);
1543                         goto st;
1544                 case BPF_ST | BPF_MEM | BPF_W:
1545                         if (is_ereg(dst_reg))
1546                                 EMIT2(0x41, 0xC7);
1547                         else
1548                                 EMIT1(0xC7);
1549                         goto st;
1550                 case BPF_ST | BPF_MEM | BPF_DW:
1551                         EMIT2(add_1mod(0x48, dst_reg), 0xC7);
1552
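                        /*
                         * Common tail for the ST cases: ModRM 0x40|dst selects
                         * [dst + disp8] (mod=01) and 0x80|dst selects
                         * [dst + disp32] (mod=10); the immediate follows,
                         * sized by bpf_size_to_x86_bytes().
                         */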
1553 st:                     if (is_imm8(insn->off))
1554                                 EMIT2(add_1reg(0x40, dst_reg), insn->off);
1555                         else
1556                                 EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);
1557
1558                         EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
1559                         break;
1560
1561                         /* STX: *(u8*)(dst_reg + off) = src_reg */
1562                 case BPF_STX | BPF_MEM | BPF_B:
1563                 case BPF_STX | BPF_MEM | BPF_H:
1564                 case BPF_STX | BPF_MEM | BPF_W:
1565                 case BPF_STX | BPF_MEM | BPF_DW:
1566                         emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
1567                         break;
1568
1569                         /* LDX: dst_reg = *(u8*)(src_reg + off) */
1570                 case BPF_LDX | BPF_MEM | BPF_B:
1571                 case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1572                 case BPF_LDX | BPF_MEM | BPF_H:
1573                 case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1574                 case BPF_LDX | BPF_MEM | BPF_W:
1575                 case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1576                 case BPF_LDX | BPF_MEM | BPF_DW:
1577                 case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1578                         /* LDXS: dst_reg = *(s8*)(src_reg + off) */
1579                 case BPF_LDX | BPF_MEMSX | BPF_B:
1580                 case BPF_LDX | BPF_MEMSX | BPF_H:
1581                 case BPF_LDX | BPF_MEMSX | BPF_W:
1582                 case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1583                 case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1584                 case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1585                         insn_off = insn->off;
1586
1587                         if (BPF_MODE(insn->code) == BPF_PROBE_MEM ||
1588                             BPF_MODE(insn->code) == BPF_PROBE_MEMSX) {
1589                                 /* Conservatively check that src_reg + insn->off is a kernel address:
1590                                  *   src_reg + insn->off >= TASK_SIZE_MAX + PAGE_SIZE
1591                                  * src_reg is used as scratch for src_reg += insn->off and restored
1592                                  * after emit_ldx if necessary
1593                                  */
1594
1595                                 u64 limit = TASK_SIZE_MAX + PAGE_SIZE;
1596                                 u8 *end_of_jmp;
1597
1598                                 /* At end of these emitted checks, insn->off will have been added
1599                                  * to src_reg, so no need to do relative load with insn->off offset
1600                                  */
1601                                 insn_off = 0;
1602
1603                                 /* movabsq r11, limit */
1604                                 EMIT2(add_1mod(0x48, AUX_REG), add_1reg(0xB8, AUX_REG));
1605                                 EMIT((u32)limit, 4);
1606                                 EMIT(limit >> 32, 4);
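                                /*
                                 * movabs is REX.W + 0xB8+rd followed by a
                                 * 64-bit immediate, emitted here as two
                                 * 4-byte halves.
                                 */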
1607
1608                                 if (insn->off) {
1609                                         /* add src_reg, insn->off */
1610                                         maybe_emit_1mod(&prog, src_reg, true);
1611                                         EMIT2_off32(0x81, add_1reg(0xC0, src_reg), insn->off);
1612                                 }
1613
1614                                 /* cmp src_reg, r11 */
1615                                 maybe_emit_mod(&prog, src_reg, AUX_REG, true);
1616                                 EMIT2(0x39, add_2reg(0xC0, src_reg, AUX_REG));
1617
1618                                 /* if unsigned '>=', goto load */
1619                                 EMIT2(X86_JAE, 0);
1620                                 end_of_jmp = prog;
1621
1622                                 /* xor dst_reg, dst_reg */
1623                                 emit_mov_imm32(&prog, false, dst_reg, 0);
1624                                 /* jmp byte_after_ldx */
1625                                 EMIT2(0xEB, 0);
1626
1627                                 /* populate jmp_offset for JAE above to jump to start_of_ldx */
1628                                 start_of_ldx = prog;
1629                                 end_of_jmp[-1] = start_of_ldx - end_of_jmp;
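                                /*
                                 * Roughly, the guard emitted above is:
                                 *
                                 *   movabs r11, TASK_SIZE_MAX + PAGE_SIZE
                                 *   add    src_reg, insn->off   ; only if off != 0
                                 *   cmp    src_reg, r11
                                 *   jae    start_of_ldx         ; kernel address, do the load
                                 *   xor    dst_reg, dst_reg     ; 32-bit xor, result is 0
                                 *   jmp    byte_after_ldx
                                 * start_of_ldx:
                                 */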
1630                         }
1631                         if (BPF_MODE(insn->code) == BPF_PROBE_MEMSX ||
1632                             BPF_MODE(insn->code) == BPF_MEMSX)
1633                                 emit_ldsx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
1634                         else
1635                                 emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
1636                         if (BPF_MODE(insn->code) == BPF_PROBE_MEM ||
1637                             BPF_MODE(insn->code) == BPF_PROBE_MEMSX) {
1638                                 struct exception_table_entry *ex;
1639                                 u8 *_insn = image + proglen + (start_of_ldx - temp);
1640                                 s64 delta;
1641
1642                                 /* populate jmp_offset for JMP above */
1643                                 start_of_ldx[-1] = prog - start_of_ldx;
1644
1645                                 if (insn->off && src_reg != dst_reg) {
1646                                         /* sub src_reg, insn->off
1647                                          * Restore src_reg after "add src_reg, insn->off" in prev
1648                                          * if statement. But if src_reg == dst_reg, emit_ldx
1649                                          * above already clobbered src_reg, so no need to restore.
1650                                          * If add src_reg, insn->off was unnecessary, no need to
1651                                          * restore either.
1652                                          */
1653                                         maybe_emit_1mod(&prog, src_reg, true);
1654                                         EMIT2_off32(0x81, add_1reg(0xE8, src_reg), insn->off);
1655                                 }
1656
1657                                 if (!bpf_prog->aux->extable)
1658                                         break;
1659
1660                                 if (excnt >= bpf_prog->aux->num_exentries) {
1661                                         pr_err("ex gen bug\n");
1662                                         return -EFAULT;
1663                                 }
1664                                 ex = &bpf_prog->aux->extable[excnt++];
1665
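                                /*
                                 * The x86 exception table stores 'insn' as a
                                 * 32-bit offset relative to the entry itself,
                                 * hence the is_simm32() check on delta below.
                                 */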
1666                                 delta = _insn - (u8 *)&ex->insn;
1667                                 if (!is_simm32(delta)) {
1668                                         pr_err("extable->insn doesn't fit into 32-bit\n");
1669                                         return -EFAULT;
1670                                 }
1671                                 /* switch ex to rw buffer for writes */
1672                                 ex = (void *)rw_image + ((void *)ex - (void *)image);
1673
1674                                 ex->insn = delta;
1675
1676                                 ex->data = EX_TYPE_BPF;
1677
1678                                 if (dst_reg > BPF_REG_9) {
1679                                         pr_err("verifier error\n");
1680                                         return -EFAULT;
1681                                 }
1682                                 /*
1683                                  * Compute size of x86 insn and its target dest x86 register.
1684                                  * ex_handler_bpf() will use lower 8 bits to adjust
1685                                  * pt_regs->ip to jump over this x86 instruction
1686                                  * and upper bits to figure out which pt_regs to zero out.
1687                                  * End result: x86 insn "mov rbx, qword ptr [rax+0x14]"
1688                                  * of 4 bytes will be ignored and rbx will be zero inited.
1689                                  */
1690                                 ex->fixup = (prog - start_of_ldx) | (reg2pt_regs[dst_reg] << 8);
1691                         }
1692                         break;
1693
1694                 case BPF_STX | BPF_ATOMIC | BPF_W:
1695                 case BPF_STX | BPF_ATOMIC | BPF_DW:
1696                         if (insn->imm == (BPF_AND | BPF_FETCH) ||
1697                             insn->imm == (BPF_OR | BPF_FETCH) ||
1698                             insn->imm == (BPF_XOR | BPF_FETCH)) {
1699                                 bool is64 = BPF_SIZE(insn->code) == BPF_DW;
1700                                 u32 real_src_reg = src_reg;
1701                                 u32 real_dst_reg = dst_reg;
1702                                 u8 *branch_target;
1703
1704                                 /*
1705                                  * Can't be implemented with a single x86 insn.
1706                                  * Need to do a CMPXCHG loop.
1707                                  */
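                                /*
                                 * The loop built below is roughly:
                                 *
                                 *   mov  BPF_REG_AX, r0          ; save R0, rax is needed by cmpxchg
                                 * again:
                                 *   mov  r0, [dst + off]         ; load the old value
                                 *   mov  AUX_REG, r0
                                 *   <op> AUX_REG, src            ; and/or/xor
                                 *   lock cmpxchg [dst + off], AUX_REG
                                 *   jne  again                   ; lost the race, retry
                                 *   mov  src, r0                 ; return the old value
                                 *   mov  r0, BPF_REG_AX          ; restore R0
                                 */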
1708
1709                                 /* Will need RAX as a CMPXCHG operand so save R0 */
1710                                 emit_mov_reg(&prog, true, BPF_REG_AX, BPF_REG_0);
1711                                 if (src_reg == BPF_REG_0)
1712                                         real_src_reg = BPF_REG_AX;
1713                                 if (dst_reg == BPF_REG_0)
1714                                         real_dst_reg = BPF_REG_AX;
1715
1716                                 branch_target = prog;
1717                                 /* Load old value */
1718                                 emit_ldx(&prog, BPF_SIZE(insn->code),
1719                                          BPF_REG_0, real_dst_reg, insn->off);
1720                                 /*
1721                                  * Perform the (commutative) operation locally,
1722                                  * put the result in the AUX_REG.
1723                                  */
1724                                 emit_mov_reg(&prog, is64, AUX_REG, BPF_REG_0);
1725                                 maybe_emit_mod(&prog, AUX_REG, real_src_reg, is64);
1726                                 EMIT2(simple_alu_opcodes[BPF_OP(insn->imm)],
1727                                       add_2reg(0xC0, AUX_REG, real_src_reg));
1728                                 /* Attempt to swap in new value */
1729                                 err = emit_atomic(&prog, BPF_CMPXCHG,
1730                                                   real_dst_reg, AUX_REG,
1731                                                   insn->off,
1732                                                   BPF_SIZE(insn->code));
1733                                 if (WARN_ON(err))
1734                                         return err;
1735                                 /*
1736                                  * ZF tells us whether we won the race. If it's
1737                                  * cleared we need to try again.
1738                                  */
1739                                 EMIT2(X86_JNE, -(prog - branch_target) - 2);
1740                                 /* Return the pre-modification value */
1741                                 emit_mov_reg(&prog, is64, real_src_reg, BPF_REG_0);
1742                                 /* Restore R0 after clobbering RAX */
1743                                 emit_mov_reg(&prog, true, BPF_REG_0, BPF_REG_AX);
1744                                 break;
1745                         }
1746
1747                         err = emit_atomic(&prog, insn->imm, dst_reg, src_reg,
1748                                           insn->off, BPF_SIZE(insn->code));
1749                         if (err)
1750                                 return err;
1751                         break;
1752
1753                         /* call */
1754                 case BPF_JMP | BPF_CALL: {
1755                         int offs;
1756
1757                         func = (u8 *) __bpf_call_base + imm32;
1758                         if (tail_call_reachable) {
1759                                 RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth);
1760                                 if (!imm32)
1761                                         return -EINVAL;
1762                                 offs = 7 + x86_call_depth_emit_accounting(&prog, func);
1763                         } else {
1764                                 if (!imm32)
1765                                         return -EINVAL;
1766                                 offs = x86_call_depth_emit_accounting(&prog, func);
1767                         }
1768                         if (emit_call(&prog, func, image + addrs[i - 1] + offs))
1769                                 return -EINVAL;
1770                         break;
1771                 }
1772
1773                 case BPF_JMP | BPF_TAIL_CALL:
1774                         if (imm32)
1775                                 emit_bpf_tail_call_direct(bpf_prog,
1776                                                           &bpf_prog->aux->poke_tab[imm32 - 1],
1777                                                           &prog, image + addrs[i - 1],
1778                                                           callee_regs_used,
1779                                                           bpf_prog->aux->stack_depth,
1780                                                           ctx);
1781                         else
1782                                 emit_bpf_tail_call_indirect(bpf_prog,
1783                                                             &prog,
1784                                                             callee_regs_used,
1785                                                             bpf_prog->aux->stack_depth,
1786                                                             image + addrs[i - 1],
1787                                                             ctx);
1788                         break;
1789
1790                         /* cond jump */
1791                 case BPF_JMP | BPF_JEQ | BPF_X:
1792                 case BPF_JMP | BPF_JNE | BPF_X:
1793                 case BPF_JMP | BPF_JGT | BPF_X:
1794                 case BPF_JMP | BPF_JLT | BPF_X:
1795                 case BPF_JMP | BPF_JGE | BPF_X:
1796                 case BPF_JMP | BPF_JLE | BPF_X:
1797                 case BPF_JMP | BPF_JSGT | BPF_X:
1798                 case BPF_JMP | BPF_JSLT | BPF_X:
1799                 case BPF_JMP | BPF_JSGE | BPF_X:
1800                 case BPF_JMP | BPF_JSLE | BPF_X:
1801                 case BPF_JMP32 | BPF_JEQ | BPF_X:
1802                 case BPF_JMP32 | BPF_JNE | BPF_X:
1803                 case BPF_JMP32 | BPF_JGT | BPF_X:
1804                 case BPF_JMP32 | BPF_JLT | BPF_X:
1805                 case BPF_JMP32 | BPF_JGE | BPF_X:
1806                 case BPF_JMP32 | BPF_JLE | BPF_X:
1807                 case BPF_JMP32 | BPF_JSGT | BPF_X:
1808                 case BPF_JMP32 | BPF_JSLT | BPF_X:
1809                 case BPF_JMP32 | BPF_JSGE | BPF_X:
1810                 case BPF_JMP32 | BPF_JSLE | BPF_X:
1811                         /* cmp dst_reg, src_reg */
1812                         maybe_emit_mod(&prog, dst_reg, src_reg,
1813                                        BPF_CLASS(insn->code) == BPF_JMP);
1814                         EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
1815                         goto emit_cond_jmp;
1816
1817                 case BPF_JMP | BPF_JSET | BPF_X:
1818                 case BPF_JMP32 | BPF_JSET | BPF_X:
1819                         /* test dst_reg, src_reg */
1820                         maybe_emit_mod(&prog, dst_reg, src_reg,
1821                                        BPF_CLASS(insn->code) == BPF_JMP);
1822                         EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
1823                         goto emit_cond_jmp;
1824
1825                 case BPF_JMP | BPF_JSET | BPF_K:
1826                 case BPF_JMP32 | BPF_JSET | BPF_K:
1827                         /* test dst_reg, imm32 */
1828                         maybe_emit_1mod(&prog, dst_reg,
1829                                         BPF_CLASS(insn->code) == BPF_JMP);
1830                         EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
1831                         goto emit_cond_jmp;
1832
1833                 case BPF_JMP | BPF_JEQ | BPF_K:
1834                 case BPF_JMP | BPF_JNE | BPF_K:
1835                 case BPF_JMP | BPF_JGT | BPF_K:
1836                 case BPF_JMP | BPF_JLT | BPF_K:
1837                 case BPF_JMP | BPF_JGE | BPF_K:
1838                 case BPF_JMP | BPF_JLE | BPF_K:
1839                 case BPF_JMP | BPF_JSGT | BPF_K:
1840                 case BPF_JMP | BPF_JSLT | BPF_K:
1841                 case BPF_JMP | BPF_JSGE | BPF_K:
1842                 case BPF_JMP | BPF_JSLE | BPF_K:
1843                 case BPF_JMP32 | BPF_JEQ | BPF_K:
1844                 case BPF_JMP32 | BPF_JNE | BPF_K:
1845                 case BPF_JMP32 | BPF_JGT | BPF_K:
1846                 case BPF_JMP32 | BPF_JLT | BPF_K:
1847                 case BPF_JMP32 | BPF_JGE | BPF_K:
1848                 case BPF_JMP32 | BPF_JLE | BPF_K:
1849                 case BPF_JMP32 | BPF_JSGT | BPF_K:
1850                 case BPF_JMP32 | BPF_JSLT | BPF_K:
1851                 case BPF_JMP32 | BPF_JSGE | BPF_K:
1852                 case BPF_JMP32 | BPF_JSLE | BPF_K:
1853                         /* test dst_reg, dst_reg to save one extra byte */
1854                         if (imm32 == 0) {
1855                                 maybe_emit_mod(&prog, dst_reg, dst_reg,
1856                                                BPF_CLASS(insn->code) == BPF_JMP);
1857                                 EMIT2(0x85, add_2reg(0xC0, dst_reg, dst_reg));
1858                                 goto emit_cond_jmp;
1859                         }
1860
1861                         /* cmp dst_reg, imm8/32 */
1862                         maybe_emit_1mod(&prog, dst_reg,
1863                                         BPF_CLASS(insn->code) == BPF_JMP);
1864
1865                         if (is_imm8(imm32))
1866                                 EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
1867                         else
1868                                 EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);
1869
1870 emit_cond_jmp:          /* Convert BPF opcode to x86 */
1871                         switch (BPF_OP(insn->code)) {
1872                         case BPF_JEQ:
1873                                 jmp_cond = X86_JE;
1874                                 break;
1875                         case BPF_JSET:
1876                         case BPF_JNE:
1877                                 jmp_cond = X86_JNE;
1878                                 break;
1879                         case BPF_JGT:
1880                                 /* GT is unsigned '>', JA in x86 */
1881                                 jmp_cond = X86_JA;
1882                                 break;
1883                         case BPF_JLT:
1884                                 /* LT is unsigned '<', JB in x86 */
1885                                 jmp_cond = X86_JB;
1886                                 break;
1887                         case BPF_JGE:
1888                                 /* GE is unsigned '>=', JAE in x86 */
1889                                 jmp_cond = X86_JAE;
1890                                 break;
1891                         case BPF_JLE:
1892                                 /* LE is unsigned '<=', JBE in x86 */
1893                                 jmp_cond = X86_JBE;
1894                                 break;
1895                         case BPF_JSGT:
1896                                 /* Signed '>', GT in x86 */
1897                                 jmp_cond = X86_JG;
1898                                 break;
1899                         case BPF_JSLT:
1900                                 /* Signed '<', LT in x86 */
1901                                 jmp_cond = X86_JL;
1902                                 break;
1903                         case BPF_JSGE:
1904                                 /* Signed '>=', GE in x86 */
1905                                 jmp_cond = X86_JGE;
1906                                 break;
1907                         case BPF_JSLE:
1908                                 /* Signed '<=', LE in x86 */
1909                                 jmp_cond = X86_JLE;
1910                                 break;
1911                         default: /* to silence GCC warning */
1912                                 return -EFAULT;
1913                         }
1914                         jmp_offset = addrs[i + insn->off] - addrs[i];
1915                         if (is_imm8(jmp_offset)) {
1916                                 if (jmp_padding) {
1917                                         /* To keep the jmp_offset valid, the extra bytes are
1918                                          * padded before the jump insn, so we subtract the
1919                                          * 2 bytes of jmp_cond insn from INSN_SZ_DIFF.
1920                                          *
1921                                          * If the previous pass already emits an imm8
1922                                          * jmp_cond, then this BPF insn won't shrink, so
1923                                          * "nops" is 0.
1924                                          *
1925                                          * On the other hand, if the previous pass emits an
1926                                          * imm32 jmp_cond, the extra 4 bytes(*) is padded to
1927                                          * keep the image from shrinking further.
1928                                          *
1929                                          * (*) imm32 jmp_cond is 6 bytes, and imm8 jmp_cond
1930                                          *     is 2 bytes, so the size difference is 4 bytes.
1931                                          */
1932                                         nops = INSN_SZ_DIFF - 2;
1933                                         if (nops != 0 && nops != 4) {
1934                                                 pr_err("unexpected jmp_cond padding: %d bytes\n",
1935                                                        nops);
1936                                                 return -EFAULT;
1937                                         }
1938                                         emit_nops(&prog, nops);
1939                                 }
1940                                 EMIT2(jmp_cond, jmp_offset);
1941                         } else if (is_simm32(jmp_offset)) {
1942                                 EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
1943                         } else {
1944                                 pr_err("cond_jmp gen bug %llx\n", jmp_offset);
1945                                 return -EFAULT;
1946                         }
1947
1948                         break;
1949
1950                 case BPF_JMP | BPF_JA:
1951                 case BPF_JMP32 | BPF_JA:
1952                         if (BPF_CLASS(insn->code) == BPF_JMP) {
1953                                 if (insn->off == -1)
1954                                         /* -1 jmp instructions will always jump
1955                                          * backwards two bytes. Explicitly handling
1956                                          * this case avoids wasting too many passes
1957                                          * when there are long sequences of replaced
1958                                          * dead code.
1959                                          */
1960                                         jmp_offset = -2;
1961                                 else
1962                                         jmp_offset = addrs[i + insn->off] - addrs[i];
1963                         } else {
1964                                 if (insn->imm == -1)
1965                                         jmp_offset = -2;
1966                                 else
1967                                         jmp_offset = addrs[i + insn->imm] - addrs[i];
1968                         }
1969
1970                         if (!jmp_offset) {
1971                                 /*
1972                                  * If jmp_padding is enabled, the extra nops will
1973                                  * be inserted. Otherwise, optimize out nop jumps.
1974                                  */
1975                                 if (jmp_padding) {
1976                                         /* There are 3 possible conditions.
1977                                          * (1) This BPF_JA is already optimized out in
1978                                          *     the previous run, so there is no need
1979                                          *     to pad any extra byte (0 byte).
1980                                          * (2) The previous pass emits an imm8 jmp,
1981                                          *     so we pad 2 bytes to match the previous
1982                                          *     insn size.
1983                                          * (3) Similarly, the previous pass emits an
1984                                          *     imm32 jmp, and 5 bytes is padded.
1985                                          */
1986                                         nops = INSN_SZ_DIFF;
1987                                         if (nops != 0 && nops != 2 && nops != 5) {
1988                                                 pr_err("unexpected nop jump padding: %d bytes\n",
1989                                                        nops);
1990                                                 return -EFAULT;
1991                                         }
1992                                         emit_nops(&prog, nops);
1993                                 }
1994                                 break;
1995                         }
1996 emit_jmp:
1997                         if (is_imm8(jmp_offset)) {
1998                                 if (jmp_padding) {
1999                                         /* To avoid breaking jmp_offset, the extra bytes
2000                                          * are padded before the actual jmp insn, so
2001                                          * 2 bytes is subtracted from INSN_SZ_DIFF.
2002                                          *
2003                                          * If the previous pass already emits an imm8
2004                                          * jmp, there is nothing to pad (0 byte).
2005                                          *
2006                                          * If it emits an imm32 jmp (5 bytes) previously
2007                                          * and now an imm8 jmp (2 bytes), then we pad
2008                                          * (5 - 2 = 3) bytes to stop the image from
2009                                          * shrinking further.
2010                                          */
2011                                         nops = INSN_SZ_DIFF - 2;
2012                                         if (nops != 0 && nops != 3) {
2013                                                 pr_err("unexpected jump padding: %d bytes\n",
2014                                                        nops);
2015                                                 return -EFAULT;
2016                                         }
2017                                         emit_nops(&prog, INSN_SZ_DIFF - 2);
2018                                 }
2019                                 EMIT2(0xEB, jmp_offset);
2020                         } else if (is_simm32(jmp_offset)) {
2021                                 EMIT1_off32(0xE9, jmp_offset);
2022                         } else {
2023                                 pr_err("jmp gen bug %llx\n", jmp_offset);
2024                                 return -EFAULT;
2025                         }
2026                         break;
2027
2028                 case BPF_JMP | BPF_EXIT:
2029                         if (seen_exit) {
2030                                 jmp_offset = ctx->cleanup_addr - addrs[i];
2031                                 goto emit_jmp;
2032                         }
2033                         seen_exit = true;
2034                         /* Update cleanup_addr */
2035                         ctx->cleanup_addr = proglen;
2036                         if (bpf_prog->aux->exception_boundary) {
2037                                 pop_callee_regs(&prog, all_callee_regs_used);
2038                                 pop_r12(&prog);
2039                         } else {
2040                                 pop_callee_regs(&prog, callee_regs_used);
2041                         }
2042                         EMIT1(0xC9);         /* leave */
2043                         emit_return(&prog, image + addrs[i - 1] + (prog - temp));
2044                         break;
2045
2046                 default:
2047                         /*
2048                          * By design x86-64 JIT should support all BPF instructions.
2049                          * This error will be seen if new instruction was added
2050                          * to the interpreter, but not to the JIT, or if there is
2051                          * junk in bpf_prog.
2052                          */
2053                         pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
2054                         return -EINVAL;
2055                 }
2056
2057                 ilen = prog - temp;
2058                 if (ilen > BPF_MAX_INSN_SIZE) {
2059                         pr_err("bpf_jit: fatal insn size error\n");
2060                         return -EFAULT;
2061                 }
2062
2063                 if (image) {
2064                         /*
2065                          * When populating the image, assert that:
2066                          *
2067                          *  i) We do not write beyond the allocated space, and
2068                          * ii) addrs[i] did not change from the prior run, in order
2069                          *     to validate assumptions made for computing branch
2070                          *     displacements.
2071                          */
2072                         if (unlikely(proglen + ilen > oldproglen ||
2073                                      proglen + ilen != addrs[i])) {
2074                                 pr_err("bpf_jit: fatal error\n");
2075                                 return -EFAULT;
2076                         }
2077                         memcpy(rw_image + proglen, temp, ilen);
2078                 }
2079                 proglen += ilen;
2080                 addrs[i] = proglen;
2081                 prog = temp;
2082         }
2083
2084         if (image && excnt != bpf_prog->aux->num_exentries) {
2085                 pr_err("extable is not populated\n");
2086                 return -EFAULT;
2087         }
2088         return proglen;
2089 }
2090
2091 static void clean_stack_garbage(const struct btf_func_model *m,
2092                                 u8 **pprog, int nr_stack_slots,
2093                                 int stack_size)
2094 {
2095         int arg_size, off;
2096         u8 *prog;
2097
2098         /* Generally speaking, the compiler will pass the arguments
2099          * on-stack with a "push" instruction, which takes 8 bytes
2100          * on the stack. In this case, there won't be garbage values
2101          * while we copy the arguments from the origin stack frame to
2102          * the current one in BPF_DW.
2103          *
2104          * However, sometimes the compiler will only allocate 4 bytes on
2105          * the stack for the arguments. For now, this case will only
2106          * happen if there is only one argument on-stack and its size is
2107          * not more than 4 bytes. In this case, there will be garbage
2108          * values in the upper 4 bytes where we store the argument in
2109          * the current stack frame.
2110          *
2111          * arguments on origin stack:
2112          *
2113          * stack_arg_1(4-byte) xxx(4-byte)
2114          *
2115          * what we copy:
2116          *
2117          * stack_arg_1(8-byte): stack_arg_1(origin) xxx
2118          *
2119          * and the xxx is the garbage values which we should clean here.
2120          */
2121         if (nr_stack_slots != 1)
2122                 return;
2123
2124         /* the size of the last argument */
2125         arg_size = m->arg_size[m->nr_args - 1];
2126         if (arg_size <= 4) {
2127                 off = -(stack_size - 4);
2128                 prog = *pprog;
2129                 /* mov DWORD PTR [rbp + off], 0 */
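                /*
                 * 0xC7 /0 is "mov r/m32, imm32"; ModRM 0x45 addresses
                 * [rbp + disp8] and 0x85 addresses [rbp + disp32].
                 */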
2130                 if (!is_imm8(off))
2131                         EMIT2_off32(0xC7, 0x85, off);
2132                 else
2133                         EMIT3(0xC7, 0x45, off);
2134                 EMIT(0, 4);
2135                 *pprog = prog;
2136         }
2137 }
2138
2139 /* get the count of the regs that are used to pass arguments */
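/*
 * Per the x86-64 SysV calling convention at most six integer registers
 * (rdi, rsi, rdx, rcx, r8, r9) carry arguments; a struct argument can
 * occupy more than one of them, hence the (arg_size + 7) / 8 slots per arg.
 */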
2140 static int get_nr_used_regs(const struct btf_func_model *m)
2141 {
2142         int i, arg_regs, nr_used_regs = 0;
2143
2144         for (i = 0; i < min_t(int, m->nr_args, MAX_BPF_FUNC_ARGS); i++) {
2145                 arg_regs = (m->arg_size[i] + 7) / 8;
2146                 if (nr_used_regs + arg_regs <= 6)
2147                         nr_used_regs += arg_regs;
2148
2149                 if (nr_used_regs >= 6)
2150                         break;
2151         }
2152
2153         return nr_used_regs;
2154 }
2155
2156 static void save_args(const struct btf_func_model *m, u8 **prog,
2157                       int stack_size, bool for_call_origin)
2158 {
2159         int arg_regs, first_off = 0, nr_regs = 0, nr_stack_slots = 0;
2160         int i, j;
2161
2162         /* Store function arguments to stack.
2163          * For a function that accepts two pointers the sequence will be:
2164          * mov QWORD PTR [rbp-0x10],rdi
2165          * mov QWORD PTR [rbp-0x8],rsi
2166          */
2167         for (i = 0; i < min_t(int, m->nr_args, MAX_BPF_FUNC_ARGS); i++) {
2168                 arg_regs = (m->arg_size[i] + 7) / 8;
2169
2170                 /* According to the research of Yonghong, struct members
2171                  * should be either all in registers or all on the stack.
2172                  * Meanwhile, the compiler will pass an argument in regs
2173                  * whenever the remaining regs can hold it.
2174                  *
2175                  * The args can therefore end up out of order. For example:
2176                  *
2177                  * struct foo_struct {
2178                  *     long a;
2179                  *     int b;
2180                  * };
2181                  * int foo(char, char, char, char, char, struct foo_struct,
2182                  *         char);
2183                  *
2184                  * the arg1-5 and arg7 will be passed by regs, and arg6 will
2185                  * be passed on the stack.
2186                  */
2187                 if (nr_regs + arg_regs > 6) {
2188                         /* copy function arguments from origin stack frame
2189                          * into current stack frame.
2190                          *
2191                          * The starting address of the arguments on-stack
2192                          * is:
2193                          *   rbp + 8(push rbp) +
2194                          *   8(return addr of origin call) +
2195                          *   8(return addr of the caller)
2196                          * which means: rbp + 24
2197                          */
2198                         for (j = 0; j < arg_regs; j++) {
2199                                 emit_ldx(prog, BPF_DW, BPF_REG_0, BPF_REG_FP,
2200                                          nr_stack_slots * 8 + 0x18);
2201                                 emit_stx(prog, BPF_DW, BPF_REG_FP, BPF_REG_0,
2202                                          -stack_size);
2203
2204                                 if (!nr_stack_slots)
2205                                         first_off = stack_size;
2206                                 stack_size -= 8;
2207                                 nr_stack_slots++;
2208                         }
2209                 } else {
2210                         /* Only copy the on-stack arguments to the current
2211                          * 'stack_size' and ignore the regs; this is used to
2212                          * prepare the on-stack arguments for the origin call.
2213                          */
2214                         if (for_call_origin) {
2215                                 nr_regs += arg_regs;
2216                                 continue;
2217                         }
2218
2219                         /* copy the arguments from regs into stack */
2220                         for (j = 0; j < arg_regs; j++) {
2221                                 emit_stx(prog, BPF_DW, BPF_REG_FP,
2222                                          nr_regs == 5 ? X86_REG_R9 : BPF_REG_1 + nr_regs,
2223                                          -stack_size);
2224                                 stack_size -= 8;
2225                                 nr_regs++;
2226                         }
2227                 }
2228         }
2229
2230         clean_stack_garbage(m, prog, nr_stack_slots, first_off);
2231 }
2232
2233 static void restore_regs(const struct btf_func_model *m, u8 **prog,
2234                          int stack_size)
2235 {
2236         int i, j, arg_regs, nr_regs = 0;
2237
2238         /* Restore function arguments from stack.
2239          * For a function that accepts two pointers the sequence will be:
2240          * EMIT4(0x48, 0x8B, 0x7D, 0xF0); mov rdi,QWORD PTR [rbp-0x10]
2241          * EMIT4(0x48, 0x8B, 0x75, 0xF8); mov rsi,QWORD PTR [rbp-0x8]
2242          *
2243          * The logic here is similar to what we do in save_args()
2244          */
2245         for (i = 0; i < min_t(int, m->nr_args, MAX_BPF_FUNC_ARGS); i++) {
2246                 arg_regs = (m->arg_size[i] + 7) / 8;
2247                 if (nr_regs + arg_regs <= 6) {
2248                         for (j = 0; j < arg_regs; j++) {
2249                                 emit_ldx(prog, BPF_DW,
2250                                          nr_regs == 5 ? X86_REG_R9 : BPF_REG_1 + nr_regs,
2251                                          BPF_REG_FP,
2252                                          -stack_size);
2253                                 stack_size -= 8;
2254                                 nr_regs++;
2255                         }
2256                 } else {
2257                         stack_size -= 8 * arg_regs;
2258                 }
2259
2260                 if (nr_regs >= 6)
2261                         break;
2262         }
2263 }
2264
2265 static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog,
2266                            struct bpf_tramp_link *l, int stack_size,
2267                            int run_ctx_off, bool save_ret,
2268                            void *image, void *rw_image)
2269 {
2270         u8 *prog = *pprog;
2271         u8 *jmp_insn;
2272         int ctx_cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
2273         struct bpf_prog *p = l->link.prog;
2274         u64 cookie = l->cookie;
2275
2276         /* mov rdi, cookie */
2277         emit_mov_imm64(&prog, BPF_REG_1, (long) cookie >> 32, (u32) (long) cookie);
2278
2279         /* Prepare struct bpf_tramp_run_ctx.
2280          *
2281          * bpf_tramp_run_ctx is already preserved by
2282          * arch_prepare_bpf_trampoline().
2283          *
2284          * mov QWORD PTR [rbp - run_ctx_off + ctx_cookie_off], rdi
2285          */
2286         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_1, -run_ctx_off + ctx_cookie_off);
2287
2288         /* arg1: mov rdi, progs[i] */
2289         emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32, (u32) (long) p);
2290         /* arg2: lea rsi, [rbp - run_ctx_off] */
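        /* 0x48 0x8d is REX.W lea; ModRM 0x75 selects rsi, [rbp + disp8] and 0xb5 rsi, [rbp + disp32] */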
2291         if (!is_imm8(-run_ctx_off))
2292                 EMIT3_off32(0x48, 0x8D, 0xB5, -run_ctx_off);
2293         else
2294                 EMIT4(0x48, 0x8D, 0x75, -run_ctx_off);
2295
2296         if (emit_rsb_call(&prog, bpf_trampoline_enter(p), image + (prog - (u8 *)rw_image)))
2297                 return -EINVAL;
2298         /* remember prog start time returned by __bpf_prog_enter */
2299         emit_mov_reg(&prog, true, BPF_REG_6, BPF_REG_0);
2300
2301         /* if (__bpf_prog_enter*(prog) == 0)
2302          *      goto skip_exec_of_prog;
2303          */
2304         EMIT3(0x48, 0x85, 0xC0);  /* test rax,rax */
2305         /* emit 2 nops that will be replaced with JE insn */
2306         jmp_insn = prog;
2307         emit_nops(&prog, 2);
2308
2309         /* arg1: lea rdi, [rbp - stack_size] */
2310         if (!is_imm8(-stack_size))
2311                 EMIT3_off32(0x48, 0x8D, 0xBD, -stack_size);
2312         else
2313                 EMIT4(0x48, 0x8D, 0x7D, -stack_size);
2314         /* arg2: progs[i]->insnsi for interpreter */
2315         if (!p->jited)
2316                 emit_mov_imm64(&prog, BPF_REG_2,
2317                                (long) p->insnsi >> 32,
2318                                (u32) (long) p->insnsi);
2319         /* call JITed bpf program or interpreter */
2320         if (emit_rsb_call(&prog, p->bpf_func, image + (prog - (u8 *)rw_image)))
2321                 return -EINVAL;
2322
2323         /*
2324          * BPF_TRAMP_MODIFY_RETURN trampolines can modify the return
2325          * of the previous call which is then passed on the stack to
2326          * the next BPF program.
2327          *
2328          * BPF_TRAMP_FENTRY trampoline may need to return the return
2329          * value of BPF_PROG_TYPE_STRUCT_OPS prog.
2330          */
2331         if (save_ret)
2332                 emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
2333
2334         /* replace 2 nops with JE insn, since jmp target is known */
2335         jmp_insn[0] = X86_JE;
2336         jmp_insn[1] = prog - jmp_insn - 2;
2337
2338         /* arg1: mov rdi, progs[i] */
2339         emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32, (u32) (long) p);
2340         /* arg2: mov rsi, rbx <- start time in nsec */
2341         emit_mov_reg(&prog, true, BPF_REG_2, BPF_REG_6);
2342         /* arg3: lea rdx, [rbp - run_ctx_off] */
2343         if (!is_imm8(-run_ctx_off))
2344                 EMIT3_off32(0x48, 0x8D, 0x95, -run_ctx_off);
2345         else
2346                 EMIT4(0x48, 0x8D, 0x55, -run_ctx_off);
2347         if (emit_rsb_call(&prog, bpf_trampoline_exit(p), image + (prog - (u8 *)rw_image)))
2348                 return -EINVAL;
2349
2350         *pprog = prog;
2351         return 0;
2352 }
2353
2354 static void emit_align(u8 **pprog, u32 align)
2355 {
2356         u8 *target, *prog = *pprog;
2357
2358         target = PTR_ALIGN(prog, align);
2359         if (target != prog)
2360                 emit_nops(&prog, target - prog);
2361
2362         *pprog = prog;
2363 }
2364
2365 static int emit_cond_near_jump(u8 **pprog, void *func, void *ip, u8 jmp_cond)
2366 {
2367         u8 *prog = *pprog;
2368         s64 offset;
2369
2370         offset = func - (ip + 2 + 4);
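        /* rel32 is relative to the end of the 6-byte insn: 2 opcode bytes + 4 displacement bytes */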
2371         if (!is_simm32(offset)) {
2372                 pr_err("Target %p is out of range\n", func);
2373                 return -EINVAL;
2374         }
2375         EMIT2_off32(0x0F, jmp_cond + 0x10, offset);
2376         *pprog = prog;
2377         return 0;
2378 }
2379
2380 static int invoke_bpf(const struct btf_func_model *m, u8 **pprog,
2381                       struct bpf_tramp_links *tl, int stack_size,
2382                       int run_ctx_off, bool save_ret,
2383                       void *image, void *rw_image)
2384 {
2385         int i;
2386         u8 *prog = *pprog;
2387
2388         for (i = 0; i < tl->nr_links; i++) {
2389                 if (invoke_bpf_prog(m, &prog, tl->links[i], stack_size,
2390                                     run_ctx_off, save_ret, image, rw_image))
2391                         return -EINVAL;
2392         }
2393         *pprog = prog;
2394         return 0;
2395 }
2396
2397 static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
2398                               struct bpf_tramp_links *tl, int stack_size,
2399                               int run_ctx_off, u8 **branches,
2400                               void *image, void *rw_image)
2401 {
2402         u8 *prog = *pprog;
2403         int i;
2404
2405         /* The first fmod_ret program will receive a garbage return value.
2406          * Set this to 0 to avoid confusing the program.
2407          */
2408         emit_mov_imm32(&prog, false, BPF_REG_0, 0);
2409         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
2410         for (i = 0; i < tl->nr_links; i++) {
2411                 if (invoke_bpf_prog(m, &prog, tl->links[i], stack_size, run_ctx_off, true,
2412                                     image, rw_image))
2413                         return -EINVAL;
2414
2415                 /* mod_ret prog stored return value into [rbp - 8]. Emit:
2416                  * if (*(u64 *)(rbp - 8) !=  0)
2417                  *      goto do_fexit;
2418                  */
2419                 /* cmp QWORD PTR [rbp - 0x8], 0x0 */
2420                 EMIT4(0x48, 0x83, 0x7d, 0xf8); EMIT1(0x00);
2421
2422                 /* Save the location of the branch and generate 6 nops
2423                  * (4 bytes for an offset and 2 bytes for the jump). These nops
2424                  * are replaced with a conditional jump once do_fexit (i.e. the
2425                  * start of the fexit invocation) is finalized.
2426                  */
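                /*
                 * 6 bytes is the size of the 0x0f 0x8x + rel32 near Jcc form,
                 * the same form emit_cond_near_jump() writes when the branch
                 * target is patched in later.
                 */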
2427                 branches[i] = prog;
2428                 emit_nops(&prog, 4 + 2);
2429         }
2430
2431         *pprog = prog;
2432         return 0;
2433 }
2434
2435 /* Example:
2436  * __be16 eth_type_trans(struct sk_buff *skb, struct net_device *dev);
2437  * its 'struct btf_func_model' will be nr_args=2
2438  * The assembly code when eth_type_trans is executing after trampoline:
2439  *
2440  * push rbp
2441  * mov rbp, rsp
2442  * sub rsp, 16                     // space for skb and dev
2443  * push rbx                        // temp regs to pass start time
2444  * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
2445  * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
2446  * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
2447  * mov rbx, rax                    // remember start time if bpf stats are enabled
2448  * lea rdi, [rbp - 16]             // R1==ctx of bpf prog
2449  * call addr_of_jited_FENTRY_prog
2450  * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
2451  * mov rsi, rbx                    // prog start time
2452  * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
2453  * mov rdi, qword ptr [rbp - 16]   // restore skb pointer from stack
2454  * mov rsi, qword ptr [rbp - 8]    // restore dev pointer from stack
2455  * pop rbx
2456  * leave
2457  * ret
2458  *
2459  * eth_type_trans has a 5-byte nop at the beginning. These 5 bytes will be
2460  * replaced with 'call generated_bpf_trampoline'. When it returns
2461  * eth_type_trans will continue executing with original skb and dev pointers.
2462  *
2463  * The assembly code when eth_type_trans is called from trampoline:
2464  *
2465  * push rbp
2466  * mov rbp, rsp
2467  * sub rsp, 24                     // space for skb, dev, return value
2468  * push rbx                        // temp regs to pass start time
2469  * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
2470  * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
2471  * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
2472  * mov rbx, rax                    // remember start time if bpf stats are enabled
2473  * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
2474  * call addr_of_jited_FENTRY_prog  // bpf prog can access skb and dev
2475  * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
2476  * mov rsi, rbx                    // prog start time
2477  * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
2478  * mov rdi, qword ptr [rbp - 24]   // restore skb pointer from stack
2479  * mov rsi, qword ptr [rbp - 16]   // restore dev pointer from stack
2480  * call eth_type_trans+5           // execute body of eth_type_trans
2481  * mov qword ptr [rbp - 8], rax    // save return value
2482  * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
2483  * mov rbx, rax                    // remember start time if bpf stats are enabled
2484  * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
2485  * call addr_of_jited_FEXIT_prog   // bpf prog can access skb, dev, return value
2486  * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
2487  * mov rsi, rbx                    // prog start time
2488  * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
2489  * mov rax, qword ptr [rbp - 8]    // restore eth_type_trans's return value
2490  * pop rbx
2491  * leave
2492  * add rsp, 8                      // skip eth_type_trans's frame
2493  * ret                             // return to its caller
2494  */
2495 static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_image,
2496                                          void *rw_image_end, void *image,
2497                                          const struct btf_func_model *m, u32 flags,
2498                                          struct bpf_tramp_links *tlinks,
2499                                          void *func_addr)
2500 {
2501         int i, ret, nr_regs = m->nr_args, stack_size = 0;
2502         int regs_off, nregs_off, ip_off, run_ctx_off, arg_stack_off, rbx_off;
2503         struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
2504         struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
2505         struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
2506         void *orig_call = func_addr;
2507         u8 **branches = NULL;
2508         u8 *prog;
2509         bool save_ret;
2510
2511         /*
2512          * F_INDIRECT is only compatible with F_RET_FENTRY_RET; it is
2513          * explicitly incompatible with F_CALL_ORIG | F_SKIP_FRAME | F_IP_ARG
2514          * because of @func_addr.
2515          */
2516         WARN_ON_ONCE((flags & BPF_TRAMP_F_INDIRECT) &&
2517                      (flags & ~(BPF_TRAMP_F_INDIRECT | BPF_TRAMP_F_RET_FENTRY_RET)));
2518
2519         /* extra registers for struct arguments */
2520         for (i = 0; i < m->nr_args; i++) {
2521                 if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
2522                         nr_regs += (m->arg_size[i] + 7) / 8 - 1;
2523         }
2524
2525         /* x86-64 supports up to MAX_BPF_FUNC_ARGS arguments. The first 6
2526          * are passed through regs, the rest are passed on the stack.
2527          */
2528         if (nr_regs > MAX_BPF_FUNC_ARGS)
2529                 return -ENOTSUPP;
2530
2531         /* Generated trampoline stack layout:
2532          *
2533          * RBP + 8         [ return address  ]
2534          * RBP + 0         [ RBP             ]
2535          *
2536          * RBP - 8         [ return value    ]  BPF_TRAMP_F_CALL_ORIG or
2537          *                                      BPF_TRAMP_F_RET_FENTRY_RET flags
2538          *
2539          *                 [ reg_argN        ]  always
2540          *                 [ ...             ]
2541          * RBP - regs_off  [ reg_arg1        ]  program's ctx pointer
2542          *
2543          * RBP - nregs_off [ regs count      ]  always
2544          *
2545          * RBP - ip_off    [ traced function ]  BPF_TRAMP_F_IP_ARG flag
2546          *
2547          * RBP - rbx_off   [ rbx value       ]  always
2548          *
2549          * RBP - run_ctx_off [ bpf_tramp_run_ctx ]
2550          *
2551          *                     [ stack_argN ]  BPF_TRAMP_F_CALL_ORIG
2552          *                     [ ...        ]
2553          *                     [ stack_arg2 ]
2554          * RBP - arg_stack_off [ stack_arg1 ]
2555          * RSP                 [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX
2556          */
2557
2558         /* room for return value of orig_call or fentry prog */
2559         save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
2560         if (save_ret)
2561                 stack_size += 8;
2562
2563         stack_size += nr_regs * 8;
2564         regs_off = stack_size;
2565
2566         /* regs count  */
2567         stack_size += 8;
2568         nregs_off = stack_size;
2569
2570         if (flags & BPF_TRAMP_F_IP_ARG)
2571                 stack_size += 8; /* room for IP address argument */
2572
2573         ip_off = stack_size;
2574
2575         stack_size += 8;
2576         rbx_off = stack_size;
2577
2578         stack_size += (sizeof(struct bpf_tramp_run_ctx) + 7) & ~0x7;
2579         run_ctx_off = stack_size;
2580
2581         if (nr_regs > 6 && (flags & BPF_TRAMP_F_CALL_ORIG)) {
2582                 /* the space used to pass arguments on the stack */
2583                 stack_size += (nr_regs - get_nr_used_regs(m)) * 8;
2584                 /* make sure the stack pointer is 16-byte aligned if we
2585                  * need to pass arguments on the stack, which means
2586                  *  [stack_size + 8(rbp) + 8(rip) + 8(original rip)]
2587                  * should be a multiple of 16. The following code depends
2588                  * on stack_size already being 8-byte aligned.
2589                  */
2590                 stack_size += (stack_size % 16) ? 0 : 8;
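                /* Worked example: an 8-byte-aligned stack_size of 48 is bumped
                 * to 56, so 56 + 8(rbp) + 8(rip) + 8(original rip) = 80, a
                 * multiple of 16; a stack_size of 56 already satisfies this
                 * and is left unchanged.
                 */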
2591         }
2592
2593         arg_stack_off = stack_size;
2594
2595         if (flags & BPF_TRAMP_F_SKIP_FRAME) {
2596                 /* skip patched call instruction and point orig_call to actual
2597                  * body of the kernel function.
2598                  */
2599                 if (is_endbr(*(u32 *)orig_call))
2600                         orig_call += ENDBR_INSN_SIZE;
2601                 orig_call += X86_PATCH_SIZE;
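                /* With IBT the function entry is a 4-byte ENDBR followed by the
                 * X86_PATCH_SIZE-byte call patched in at its __fentry__ site;
                 * skipping both lands orig_call on the real function body.
                 */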
2602         }
2603
2604         prog = rw_image;
2605
2606         if (flags & BPF_TRAMP_F_INDIRECT) {
2607                 /*
2608                  * Indirect call for bpf_struct_ops
2609                  */
2610                 emit_cfi(&prog, cfi_get_func_hash(func_addr));
2611         } else {
2612                 /*
2613                  * This is a direct-call fentry stub, so it needs to account
2614                  * for the __fentry__ call.
2615                  */
2616                 x86_call_depth_emit_accounting(&prog, NULL);
2617         }
2618         EMIT1(0x55);             /* push rbp */
2619         EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
2620         if (!is_imm8(stack_size)) {
2621                 /* sub rsp, stack_size */
2622                 EMIT3_off32(0x48, 0x81, 0xEC, stack_size);
2623         } else {
2624                 /* sub rsp, stack_size */
2625                 EMIT4(0x48, 0x83, 0xEC, stack_size);
2626         }
2627         if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
2628                 EMIT1(0x50);            /* push rax */
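        /* With BPF_TRAMP_F_TAIL_CALL_CTX, rax carries the caller's
         * tail_call_cnt; the push above parks it at RSP, matching the stack
         * layout diagram.
         */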
2629         /* mov QWORD PTR [rbp - rbx_off], rbx */
2630         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off);
2631
2632         /* Store number of argument registers of the traced function:
2633          *   mov rax, nr_regs
2634          *   mov QWORD PTR [rbp - nregs_off], rax
2635          */
2636         emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_regs);
2637         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -nregs_off);
2638
2639         if (flags & BPF_TRAMP_F_IP_ARG) {
2640                 /* Store IP address of the traced function:
2641                  * movabsq rax, func_addr
2642                  * mov QWORD PTR [rbp - ip_off], rax
2643                  */
2644                 emit_mov_imm64(&prog, BPF_REG_0, (long) func_addr >> 32, (u32) (long) func_addr);
2645                 emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -ip_off);
2646         }
2647
2648         save_args(m, &prog, regs_off, false);
2649
2650         if (flags & BPF_TRAMP_F_CALL_ORIG) {
2651                 /* arg1: mov rdi, im */
2652                 emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
2653                 if (emit_rsb_call(&prog, __bpf_tramp_enter,
2654                                   image + (prog - (u8 *)rw_image))) {
2655                         ret = -EINVAL;
2656                         goto cleanup;
2657                 }
2658         }
2659
2660         if (fentry->nr_links) {
2661                 if (invoke_bpf(m, &prog, fentry, regs_off, run_ctx_off,
2662                                flags & BPF_TRAMP_F_RET_FENTRY_RET, image, rw_image))
2663                         return -EINVAL;
2664         }
2665
2666         if (fmod_ret->nr_links) {
2667                 branches = kcalloc(fmod_ret->nr_links, sizeof(u8 *),
2668                                    GFP_KERNEL);
2669                 if (!branches)
2670                         return -ENOMEM;
2671
2672                 if (invoke_bpf_mod_ret(m, &prog, fmod_ret, regs_off,
2673                                        run_ctx_off, branches, image, rw_image)) {
2674                         ret = -EINVAL;
2675                         goto cleanup;
2676                 }
2677         }
2678
2679         if (flags & BPF_TRAMP_F_CALL_ORIG) {
2680                 restore_regs(m, &prog, regs_off);
2681                 save_args(m, &prog, arg_stack_off, true);
2682
2683                 if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
2684                         /* Before calling the original function, restore the
2685                          * tail_call_cnt from stack to rax.
2686                          */
2687                         RESTORE_TAIL_CALL_CNT(stack_size);
2688                 }
2689
2690                 if (flags & BPF_TRAMP_F_ORIG_STACK) {
2691                         emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8);
2692                         EMIT2(0xff, 0xd3); /* call *rbx */
2693                 } else {
2694                         /* call original function */
2695                         if (emit_rsb_call(&prog, orig_call, image + (prog - (u8 *)rw_image))) {
2696                                 ret = -EINVAL;
2697                                 goto cleanup;
2698                         }
2699                 }
2700                 /* remember the return value on the stack for the bpf prog to access */
2701                 emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
2702                 im->ip_after_call = image + (prog - (u8 *)rw_image);
2703                 memcpy(prog, x86_nops[5], X86_PATCH_SIZE);
2704                 prog += X86_PATCH_SIZE;
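                /* The nops emitted above sit at im->ip_after_call; the teardown
                 * path can later live-patch them into a jump to im->ip_epilogue
                 * (see bpf_tramp_image_put()) so the fexit programs are skipped
                 * while the trampoline image is being torn down.
                 */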
2705         }
2706
2707         if (fmod_ret->nr_links) {
2708                 /* From Intel 64 and IA-32 Architectures Optimization
2709                  * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
2710                  * Coding Rule 11: All branch targets should be 16-byte
2711                  * aligned.
2712                  */
2713                 emit_align(&prog, 16);
2714                 /* Update the branches saved in invoke_bpf_mod_ret with the
2715                  * aligned address of do_fexit.
2716                  */
2717                 for (i = 0; i < fmod_ret->nr_links; i++) {
2718                         emit_cond_near_jump(&branches[i], image + (prog - (u8 *)rw_image),
2719                                             image + (branches[i] - (u8 *)rw_image), X86_JNE);
2720                 }
2721         }
2722
2723         if (fexit->nr_links) {
2724                 if (invoke_bpf(m, &prog, fexit, regs_off, run_ctx_off,
2725                                false, image, rw_image)) {
2726                         ret = -EINVAL;
2727                         goto cleanup;
2728                 }
2729         }
2730
2731         if (flags & BPF_TRAMP_F_RESTORE_REGS)
2732                 restore_regs(m, &prog, regs_off);
2733
2734         /* This needs to be done regardless. If there were fmod_ret programs,
2735          * the return value is only updated on the stack and still needs to be
2736          * restored to R0.
2737          */
2738         if (flags & BPF_TRAMP_F_CALL_ORIG) {
2739                 im->ip_epilogue = image + (prog - (u8 *)rw_image);
2740                 /* arg1: mov rdi, im */
2741                 emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
2742                 if (emit_rsb_call(&prog, __bpf_tramp_exit, image + (prog - (u8 *)rw_image))) {
2743                         ret = -EINVAL;
2744                         goto cleanup;
2745                 }
2746         } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
2747                 /* Before running the original function, restore the
2748                  * tail_call_cnt from stack to rax.
2749                  */
2750                 RESTORE_TAIL_CALL_CNT(stack_size);
2751         }
2752
2753         /* restore return value of orig_call or fentry prog back into RAX */
2754         if (save_ret)
2755                 emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
2756
2757         emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, -rbx_off);
2758         EMIT1(0xC9); /* leave */
2759         if (flags & BPF_TRAMP_F_SKIP_FRAME) {
2760                 /* skip our return address and return to parent */
2761                 EMIT4(0x48, 0x83, 0xC4, 8); /* add rsp, 8 */
2762         }
2763         emit_return(&prog, image + (prog - (u8 *)rw_image));
2764         /* Make sure the trampoline generation logic doesn't overflow */
2765         if (WARN_ON_ONCE(prog > (u8 *)rw_image_end - BPF_INSN_SAFETY)) {
2766                 ret = -EFAULT;
2767                 goto cleanup;
2768         }
2769         ret = prog - (u8 *)rw_image + BPF_INSN_SAFETY;
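        /* The size returned on success includes BPF_INSN_SAFETY bytes of
         * headroom, matching the overflow check above; this is also the value
         * arch_bpf_trampoline_size() reports to its caller.
         */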
2770
2771 cleanup:
2772         kfree(branches);
2773         return ret;
2774 }
2775
2776 void *arch_alloc_bpf_trampoline(unsigned int size)
2777 {
2778         return bpf_prog_pack_alloc(size, jit_fill_hole);
2779 }
2780
2781 void arch_free_bpf_trampoline(void *image, unsigned int size)
2782 {
2783         bpf_prog_pack_free(image, size);
2784 }
2785
2786 void arch_protect_bpf_trampoline(void *image, unsigned int size)
2787 {
2788 }
2789
2790 void arch_unprotect_bpf_trampoline(void *image, unsigned int size)
2791 {
2792 }
2793
2794 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end,
2795                                 const struct btf_func_model *m, u32 flags,
2796                                 struct bpf_tramp_links *tlinks,
2797                                 void *func_addr)
2798 {
2799         void *rw_image, *tmp;
2800         int ret;
2801         u32 size = image_end - image;
2802
2803         /* rw_image doesn't need to be in the module memory range, so we
2804          * can use kvmalloc.
2805          */
2806         rw_image = kvmalloc(size, GFP_KERNEL);
2807         if (!rw_image)
2808                 return -ENOMEM;
2809
2810         ret = __arch_prepare_bpf_trampoline(im, rw_image, rw_image + size, image, m,
2811                                             flags, tlinks, func_addr);
2812         if (ret < 0)
2813                 goto out;
2814
2815         tmp = bpf_arch_text_copy(image, rw_image, size);
2816         if (IS_ERR(tmp))
2817                 ret = PTR_ERR(tmp);
2818 out:
2819         kvfree(rw_image);
2820         return ret;
2821 }
2822
2823 int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
2824                              struct bpf_tramp_links *tlinks, void *func_addr)
2825 {
2826         struct bpf_tramp_image im;
2827         void *image;
2828         int ret;
2829
2830         /* Allocate a temporary buffer for __arch_prepare_bpf_trampoline().
2831          * This will NOT cause fragmentation in the direct map, as we do
2832          * not call set_memory_*() on this buffer.
2833          *
2834          * We cannot use kvmalloc here, because we need the image to be in
2835          * the module memory range.
2836          */
2837         image = bpf_jit_alloc_exec(PAGE_SIZE);
2838         if (!image)
2839                 return -ENOMEM;
2840
2841         ret = __arch_prepare_bpf_trampoline(&im, image, image + PAGE_SIZE, image,
2842                                             m, flags, tlinks, func_addr);
2843         bpf_jit_free_exec(image);
2844         return ret;
2845 }
2846
2847 static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs, u8 *image, u8 *buf)
2848 {
2849         u8 *jg_reloc, *prog = *pprog;
2850         int pivot, err, jg_bytes = 1;
2851         s64 jg_offset;
2852
2853         if (a == b) {
2854                 /* Leaf node of recursion, i.e. not a range of indices
2855                  * anymore.
2856                  */
2857                 EMIT1(add_1mod(0x48, BPF_REG_3));       /* cmp rdx,func */
2858                 if (!is_simm32(progs[a]))
2859                         return -1;
2860                 EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3),
2861                             progs[a]);
2862                 err = emit_cond_near_jump(&prog,        /* je func */
2863                                           (void *)progs[a], image + (prog - buf),
2864                                           X86_JE);
2865                 if (err)
2866                         return err;
2867
2868                 emit_indirect_jump(&prog, 2 /* rdx */, image + (prog - buf));
2869
2870                 *pprog = prog;
2871                 return 0;
2872         }
2873
2874         /* Not a leaf node, so we pivot, and recursively descend into
2875          * the lower and upper ranges.
2876          */
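        /* The result is a balanced binary search over the sorted progs[]:
         * each inner node emits "cmp rdx, progs[pivot]; jg <upper half>" and
         * falls through into the lower half, while each leaf emits
         * "cmp rdx, progs[i]; je progs[i]" followed by an indirect jump
         * through rdx as the fallback.
         */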
2877         pivot = (b - a) / 2;
2878         EMIT1(add_1mod(0x48, BPF_REG_3));               /* cmp rdx,func */
2879         if (!is_simm32(progs[a + pivot]))
2880                 return -1;
2881         EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3), progs[a + pivot]);
2882
2883         if (pivot > 2) {                                /* jg upper_part */
2884                 /* Require near jump. */
2885                 jg_bytes = 4;
2886                 EMIT2_off32(0x0F, X86_JG + 0x10, 0);
2887         } else {
2888                 EMIT2(X86_JG, 0);
2889         }
2890         jg_reloc = prog;
2891
2892         err = emit_bpf_dispatcher(&prog, a, a + pivot,  /* emit lower_part */
2893                                   progs, image, buf);
2894         if (err)
2895                 return err;
2896
2897         /* From Intel 64 and IA-32 Architectures Optimization
2898          * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
2899          * Coding Rule 11: All branch targets should be 16-byte
2900          * aligned.
2901          */
2902         emit_align(&prog, 16);
2903         jg_offset = prog - jg_reloc;
2904         emit_code(jg_reloc - jg_bytes, jg_offset, jg_bytes);
2905
2906         err = emit_bpf_dispatcher(&prog, a + pivot + 1, /* emit upper_part */
2907                                   b, progs, image, buf);
2908         if (err)
2909                 return err;
2910
2911         *pprog = prog;
2912         return 0;
2913 }
2914
2915 static int cmp_ips(const void *a, const void *b)
2916 {
2917         const s64 *ipa = a;
2918         const s64 *ipb = b;
2919
2920         if (*ipa > *ipb)
2921                 return 1;
2922         if (*ipa < *ipb)
2923                 return -1;
2924         return 0;
2925 }
2926
2927 int arch_prepare_bpf_dispatcher(void *image, void *buf, s64 *funcs, int num_funcs)
2928 {
2929         u8 *prog = buf;
2930
2931         sort(funcs, num_funcs, sizeof(funcs[0]), cmp_ips, NULL);
2932         return emit_bpf_dispatcher(&prog, 0, num_funcs - 1, funcs, image, buf);
2933 }
2934
2935 struct x64_jit_data {
2936         struct bpf_binary_header *rw_header;
2937         struct bpf_binary_header *header;
2938         int *addrs;
2939         u8 *image;
2940         int proglen;
2941         struct jit_context ctx;
2942 };
2943
2944 #define MAX_PASSES 20
2945 #define PADDING_PASSES (MAX_PASSES - 5)
2946
2947 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
2948 {
2949         struct bpf_binary_header *rw_header = NULL;
2950         struct bpf_binary_header *header = NULL;
2951         struct bpf_prog *tmp, *orig_prog = prog;
2952         struct x64_jit_data *jit_data;
2953         int proglen, oldproglen = 0;
2954         struct jit_context ctx = {};
2955         bool tmp_blinded = false;
2956         bool extra_pass = false;
2957         bool padding = false;
2958         u8 *rw_image = NULL;
2959         u8 *image = NULL;
2960         int *addrs;
2961         int pass;
2962         int i;
2963
2964         if (!prog->jit_requested)
2965                 return orig_prog;
2966
2967         tmp = bpf_jit_blind_constants(prog);
2968         /*
2969          * If blinding was requested and we failed during blinding,
2970          * we must fall back to the interpreter.
2971          */
2972         if (IS_ERR(tmp))
2973                 return orig_prog;
2974         if (tmp != prog) {
2975                 tmp_blinded = true;
2976                 prog = tmp;
2977         }
2978
2979         jit_data = prog->aux->jit_data;
2980         if (!jit_data) {
2981                 jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
2982                 if (!jit_data) {
2983                         prog = orig_prog;
2984                         goto out;
2985                 }
2986                 prog->aux->jit_data = jit_data;
2987         }
2988         addrs = jit_data->addrs;
2989         if (addrs) {
2990                 ctx = jit_data->ctx;
2991                 oldproglen = jit_data->proglen;
2992                 image = jit_data->image;
2993                 header = jit_data->header;
2994                 rw_header = jit_data->rw_header;
2995                 rw_image = (void *)rw_header + ((void *)image - (void *)header);
2996                 extra_pass = true;
2997                 padding = true;
2998                 goto skip_init_addrs;
2999         }
3000         addrs = kvmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
3001         if (!addrs) {
3002                 prog = orig_prog;
3003                 goto out_addrs;
3004         }
3005
3006         /*
3007          * Before the first pass, make a rough estimate of addrs[]:
3008          * each BPF instruction is translated to fewer than 64 bytes.
3009          */
3010         for (proglen = 0, i = 0; i <= prog->len; i++) {
3011                 proglen += 64;
3012                 addrs[i] = proglen;
3013         }
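        /*
         * Starting from this deliberate upper bound means later passes can
         * only shrink addrs[], which is what lets the loop below converge on
         * a stable image size.
         */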
3014         ctx.cleanup_addr = proglen;
3015 skip_init_addrs:
3016
3017         /*
3018          * JITed image shrinks with every pass and the loop iterates
3019          * until the image stops shrinking. Very large BPF programs
3020          * may converge on the last pass. In such a case, do one more
3021          * pass to emit the final image.
3022          */
3023         for (pass = 0; pass < MAX_PASSES || image; pass++) {
3024                 if (!padding && pass >= PADDING_PASSES)
3025                         padding = true;
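                /* The last few passes pad instructions that came out shorter
                 * than estimated with nops (see the padding argument of
                 * do_jit()), so the image size can stabilize even if some jumps
                 * keep flipping between short and near encodings.
                 */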
3026                 proglen = do_jit(prog, addrs, image, rw_image, oldproglen, &ctx, padding);
3027                 if (proglen <= 0) {
3028 out_image:
3029                         image = NULL;
3030                         if (header) {
3031                                 bpf_arch_text_copy(&header->size, &rw_header->size,
3032                                                    sizeof(rw_header->size));
3033                                 bpf_jit_binary_pack_free(header, rw_header);
3034                         }
3035                         /* Fall back to interpreter mode */
3036                         prog = orig_prog;
3037                         if (extra_pass) {
3038                                 prog->bpf_func = NULL;
3039                                 prog->jited = 0;
3040                                 prog->jited_len = 0;
3041                         }
3042                         goto out_addrs;
3043                 }
3044                 if (image) {
3045                         if (proglen != oldproglen) {
3046                                 pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
3047                                        proglen, oldproglen);
3048                                 goto out_image;
3049                         }
3050                         break;
3051                 }
3052                 if (proglen == oldproglen) {
3053                         /*
3054                          * The number of entries in extable is the number of BPF_LDX
3055                          * insns that access kernel memory via "pointer to BTF type".
3056                          * The verifier changed their opcode from LDX|MEM|size
3057                          * to LDX|PROBE_MEM|size to make JITing easier.
3058                          */
3059                         u32 align = __alignof__(struct exception_table_entry);
3060                         u32 extable_size = prog->aux->num_exentries *
3061                                 sizeof(struct exception_table_entry);
3062
3063                         /* allocate module memory for x86 insns and extable */
3064                         header = bpf_jit_binary_pack_alloc(roundup(proglen, align) + extable_size,
3065                                                            &image, align, &rw_header, &rw_image,
3066                                                            jit_fill_hole);
3067                         if (!header) {
3068                                 prog = orig_prog;
3069                                 goto out_addrs;
3070                         }
3071                         prog->aux->extable = (void *) image + roundup(proglen, align);
3072                 }
3073                 oldproglen = proglen;
3074                 cond_resched();
3075         }
3076
3077         if (bpf_jit_enable > 1)
3078                 bpf_jit_dump(prog->len, proglen, pass + 1, rw_image);
3079
3080         if (image) {
3081                 if (!prog->is_func || extra_pass) {
3082                         /*
3083                          * bpf_jit_binary_pack_finalize fails in two scenarios:
3084                          *   1) header is not pointing to proper module memory;
3085                          *   2) the arch doesn't support bpf_arch_text_copy().
3086                          *
3087                          * Both cases are serious bugs and justify WARN_ON.
3088                          */
3089                         if (WARN_ON(bpf_jit_binary_pack_finalize(prog, header, rw_header))) {
3090                                 /* header has been freed */
3091                                 header = NULL;
3092                                 goto out_image;
3093                         }
3094
3095                         bpf_tail_call_direct_fixup(prog);
3096                 } else {
3097                         jit_data->addrs = addrs;
3098                         jit_data->ctx = ctx;
3099                         jit_data->proglen = proglen;
3100                         jit_data->image = image;
3101                         jit_data->header = header;
3102                         jit_data->rw_header = rw_header;
3103                 }
3104                 /*
3105                  * ctx.prog_offset is used when CFI preambles put code *before*
3106                  * the function. See emit_cfi(). For FineIBT specifically this code
3107                  * can also be executed, and bpf_prog_kallsyms_add() will
3108                  * generate an additional symbol to cover it; hence proglen
3109                  * is also decremented by the CFI offset.
3110                  */
3111                 prog->bpf_func = (void *)image + cfi_get_offset();
3112                 prog->jited = 1;
3113                 prog->jited_len = proglen - cfi_get_offset();
3114         } else {
3115                 prog = orig_prog;
3116         }
3117
3118         if (!image || !prog->is_func || extra_pass) {
3119                 if (image)
3120                         bpf_prog_fill_jited_linfo(prog, addrs + 1);
3121 out_addrs:
3122                 kvfree(addrs);
3123                 kfree(jit_data);
3124                 prog->aux->jit_data = NULL;
3125         }
3126 out:
3127         if (tmp_blinded)
3128                 bpf_jit_prog_release_other(prog, prog == orig_prog ?
3129                                            tmp : orig_prog);
3130         return prog;
3131 }
3132
3133 bool bpf_jit_supports_kfunc_call(void)
3134 {
3135         return true;
3136 }
3137
3138 void *bpf_arch_text_copy(void *dst, void *src, size_t len)
3139 {
3140         if (text_poke_copy(dst, src, len) == NULL)
3141                 return ERR_PTR(-EINVAL);
3142         return dst;
3143 }
3144
3145 /* Indicate the JIT backend supports mixing bpf2bpf and tailcalls. */
3146 bool bpf_jit_supports_subprog_tailcalls(void)
3147 {
3148         return true;
3149 }
3150
3151 void bpf_jit_free(struct bpf_prog *prog)
3152 {
3153         if (prog->jited) {
3154                 struct x64_jit_data *jit_data = prog->aux->jit_data;
3155                 struct bpf_binary_header *hdr;
3156
3157                 /*
3158                  * If we fail the final pass of JIT (from jit_subprogs),
3159                  * the program may not be finalized yet. Call finalize here
3160                  * before freeing it.
3161                  */
3162                 if (jit_data) {
3163                         bpf_jit_binary_pack_finalize(prog, jit_data->header,
3164                                                      jit_data->rw_header);
3165                         kvfree(jit_data->addrs);
3166                         kfree(jit_data);
3167                 }
3168                 prog->bpf_func = (void *)prog->bpf_func - cfi_get_offset();
3169                 hdr = bpf_jit_binary_pack_hdr(prog);
3170                 bpf_jit_binary_pack_free(hdr, NULL);
3171                 WARN_ON_ONCE(!bpf_prog_kallsyms_verify_off(prog));
3172         }
3173
3174         bpf_prog_unlock_free(prog);
3175 }
3176
3177 bool bpf_jit_supports_exceptions(void)
3178 {
3179         /* We unwind through both kernel frames (starting from within bpf_throw
3180          * call) and BPF frames. Therefore we require ORC unwinder to be enabled
3181          * to walk kernel frames and reach BPF frames in the stack trace.
3182          */
3183         return IS_ENABLED(CONFIG_UNWINDER_ORC);
3184 }
3185
3186 void arch_bpf_stack_walk(bool (*consume_fn)(void *cookie, u64 ip, u64 sp, u64 bp), void *cookie)
3187 {
3188 #if defined(CONFIG_UNWINDER_ORC)
3189         struct unwind_state state;
3190         unsigned long addr;
3191
3192         for (unwind_start(&state, current, NULL, NULL); !unwind_done(&state);
3193              unwind_next_frame(&state)) {
3194                 addr = unwind_get_return_address(&state);
3195                 if (!addr || !consume_fn(cookie, (u64)addr, (u64)state.sp, (u64)state.bp))
3196                         break;
3197         }
3198         return;
3199 #endif
3200         WARN(1, "verification of programs using bpf_throw should have failed\n");
3201 }
3202
3203 void bpf_arch_poke_desc_update(struct bpf_jit_poke_descriptor *poke,
3204                                struct bpf_prog *new, struct bpf_prog *old)
3205 {
3206         u8 *old_addr, *new_addr, *old_bypass_addr;
3207         int ret;
3208
3209         old_bypass_addr = old ? NULL : poke->bypass_addr;
3210         old_addr = old ? (u8 *)old->bpf_func + poke->adj_off : NULL;
3211         new_addr = new ? (u8 *)new->bpf_func + poke->adj_off : NULL;
3212
3213         /*
3214          * On program loading or teardown, the program's kallsym entry
3215          * might not be in place, so we use __bpf_arch_text_poke to skip
3216          * the kallsyms check.
3217          */
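        /* Installing a program retargets the tailcall jump (old_addr ->
         * new_addr) and, when there was no previous program, also removes the
         * bypass jump. Removal goes the other way: enable the bypass first,
         * wait for an RCU grace period so no CPU is still executing the old
         * target, then nop out the tailcall target.
         */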
3218         if (new) {
3219                 ret = __bpf_arch_text_poke(poke->tailcall_target,
3220                                            BPF_MOD_JUMP,
3221                                            old_addr, new_addr);
3222                 BUG_ON(ret < 0);
3223                 if (!old) {
3224                         ret = __bpf_arch_text_poke(poke->tailcall_bypass,
3225                                                    BPF_MOD_JUMP,
3226                                                    poke->bypass_addr,
3227                                                    NULL);
3228                         BUG_ON(ret < 0);
3229                 }
3230         } else {
3231                 ret = __bpf_arch_text_poke(poke->tailcall_bypass,
3232                                            BPF_MOD_JUMP,
3233                                            old_bypass_addr,
3234                                            poke->bypass_addr);
3235                 BUG_ON(ret < 0);
3236                 /* let other CPUs finish executing the program so that they
3237                  * cannot be exposed to an invalid nop, stack unwind, or
3238                  * nop state
3239                  */
3240                 if (!ret)
3241                         synchronize_rcu();
3242                 ret = __bpf_arch_text_poke(poke->tailcall_target,
3243                                            BPF_MOD_JUMP,
3244                                            old_addr, NULL);
3245                 BUG_ON(ret < 0);
3246         }
3247 }