]> Git Repo - linux.git/blob - tools/net/ynl/ynl-gen-c.py
treewide: remove meaningless assignments in Makefiles
[linux.git] / tools / net / ynl / ynl-gen-c.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4 import argparse
5 import collections
6 import filecmp
7 import os
8 import re
9 import shutil
10 import tempfile
11 import yaml
12
13 from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
14
15
def c_upper(name):
    """Return *name* as an upper-case C identifier, with '-' turned into '_'."""
    return name.replace('-', '_').upper()
18
19
def c_lower(name):
    """Return *name* as a lower-case C identifier, with '-' turned into '_'."""
    return name.replace('-', '_').lower()
22
23
def limit_to_number(name):
    """
    Turn a symbolic limit such as 'u32-max' or 's64-min' into its integer value.

    The first character selects the signedness ('u' or 's'). The characters
    between it and the 4-character '-min'/'-max' suffix give the bit width.
    """
    signed = name[0] == 's'
    is_min = name.endswith('-min')
    if is_min and name[0] == 'u':
        return 0
    # For signed types one bit is taken by the sign.
    bits = int(name[1:-4]) - (1 if signed else 0)
    top = (1 << bits) - 1
    return -top - 1 if (signed and is_min) else top
37
38
class BaseNlLib:
    """Produce the C snippets that depend on the netlink library in use (libmnl)."""

    def get_family_id(self):
        """Return the C expression for the family id stored in the YNL socket."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """
        Return a ``mnl_cb_run2()`` call expression as a string.

        Dumps pass 0 for the sequence number and port id. Other requests pass
        ``ys->seq`` and ``ys->portid``. Continuation lines are indented by
        ``indent`` extra tabs.
        """
        sep = '\n\t\t' + indent * '\t' + ' '
        if not is_dump:
            return (f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{sep}{cb}, {data},{sep}"
                    "ynl_cb_array, NLMSG_MIN_TYPE)")
        return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},{sep}ynl_cb_array, NLMSG_MIN_TYPE)"
50
51
class Type(SpecAttr):
    """
    Generator-side model of a single netlink attribute.

    Wraps the spec attribute and renders it into the generated C code:
    - struct member and ``_present`` member,
    - kernel policy entry and YNL type-policy (typol) entry,
    - put/parse code, setter and free code.

    Subclasses specialize the ``_``-prefixed hooks per attribute type. Base
    hooks without a sensible default raise ``Exception``.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # A nest of the set named like the family renders under the bare
            # family name; any other set becomes "<family>_<set>".
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.name}")
            else:
                self.nested_render_name = c_lower(f"{family.name}_{self.nested_attrs}")

            # If the set name is also a definition in family.consts, the
            # struct name gets a trailing '_' (mirrors Struct's naming of
            # nested structs).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # C member name; a '_' is appended when the name is listed in _C_KW
        # (defined elsewhere in this file, presumably the C keywords).
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        # declared for readers, then deleted so any use before resolve()
        # fails with AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def get_limit(self, limit, default=None):
        """
        Return check *limit* (e.g. 'min', 'max', 'max-len') as an int.

        Symbolic limits such as 'u32-max' are converted with
        limit_to_number(). *default* is used when the check is absent, and a
        resulting None is returned unchanged.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if not isinstance(value, int):
            value = limit_to_number(value)
        return value

    def resolve(self):
        """Compute the upper-case C enum name (name prefix + attribute name)."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Falsy here; subclasses holding several values return True."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        """True for a recursive attribute rendered outside of an operation."""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """How presence is tracked: 'bit' here; subclasses use 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """
        Return the ``_present`` member declaration for this attribute.

        Returns None unless the presence type equals *type_filter*. In 'user'
        space the ``__u32`` spelling is used.
        """
        if self.presence_type() != type_filter:
            return None

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """C type of a non-trivial struct member, or None if there is none."""
        return None

    def free_needs_iter(self):
        return False

    def free(self, ri, var, ref):
        """Emit free() for heap-backed members (multi-value or 'len' presence)."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the setter argument declarations for this attribute."""
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member(s) holding this attribute's value."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry in the kernel policy table."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _mnl_type(self):
        # mnl does not have helpers for signed integer types
        # turn signed type into unsigned
        # this only makes sense for scalar types
        t = self.type
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry in the YNL type-policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line*, guarded by the presence check when there is one."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars) of parse code; str is allowed for a single line."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the parse branch for this attribute.

        *first* selects ``if`` versus ``else if``. Always returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a ``static inline`` setter for this attribute.

        *ref* lists the C member names of the enclosing nests, outermost
        first, when called from TypeNest.setter(). The setter name and the
        member path are built from ``ref + [c_name]``.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            # Every level above the last is a nest (see TypeNest.setter), and
            # nests use bit presence. Their bits must always be set, otherwise
            # the parent nest is skipped at put time and the value is lost.
            # The last level is this attribute; if it does not use bit
            # presence, _setter_lines() maintains its presence instead.
            if i == len(ref) - 1 and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters that free() the old value but allocate nothing get a '__'
        # prefix. An example is the multi-attr setter, which stores the
        # caller's pointer as-is.
        free = any('free(' in line for line in code)
        alloc = any('alloc(' in line for line in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
257
258
class TypeUnused(Type):
    """
    Reserved attribute slot (spec type 'unused').

    Produces no struct member, policy entry, put/parse code or setter. Its
    typol entry is YNL_PT_REJECT.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        # If used, the parse hook bails out with an error.
        return ['return MNL_CB_ERROR;'], None, None

    def presence_type(self):
        # No presence tracking at all.
        return ''

    def arg_member(self, ri):
        return []

    # The remaining hooks intentionally generate nothing.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
283
284
class TypePad(Type):
    """
    Padding attribute (spec type 'pad').

    Produces no struct member, policy entry, put/parse code or setter. Its
    typol entry is YNL_PT_IGNORE.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def presence_type(self):
        # No presence tracking at all.
        return ''

    def arg_member(self, ri):
        return []

    # The remaining hooks intentionally generate nothing.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
306
307
class TypeScalar(Type):
    """
    Integer attribute (spec types listed in ``scalars``).

    Derives the policy checks (mask, range, min, max) from the spec's
    'checks'. When the attribute references an enum, its limits also come
    from that enum's value range.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = f" /* {attr['byte-order']} */" if 'byte-order' in attr else ''

        # An enum-typed scalar is limited to the enum's value range unless
        # the spec gives explicit limits. A lower bound of 0 is not recorded
        # for unsigned types.
        if 'enum' in self.attr:
            low, high = self.family.consts[self.attr['enum']].value_range()
            if 'min' not in self.checks and (low != 0 or self.type[0] == 's'):
                self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            lo_lim = self.get_limit('min')
            hi_lim = self.get_limit('max')
            if lo_lim > hi_lim:
                raise Exception(f'Invalid limit for "{self.name}" min: {lo_lim} max: {hi_lim}')
            self.checks['range'] = True

        bounds = (self.get_limit('min', 0), self.get_limit('max', 0))
        low, high = min(bounds), max(bounds)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range are marked 'full-range'. They are
        # rendered via NLA_POLICY_FULL_RANGE, presumably because the inline
        # policy min/max fields are only 16 bits wide.
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Determine whether the value is a flags bitfield, and its C type name."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        else:
            self.is_bitfield = ('enum' in self.attr and
                                self.family.consts[self.attr['enum']]['type'] == 'flags')

        if 'enum' in self.attr and not self.is_bitfield:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # is_auto_scalar comes from the spec layer; such values are 64-bit
            self.type_name = f"__{self.type[0]}64"
        else:
            self.type_name = f"__{self.type}"

    def mnl_type(self):
        return self._mnl_type()

    def _attr_policy(self, policy):
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            else:
                n_flags = len(self.family.consts[self.checks['flags-mask']]['entries'])
                mask = (1 << n_flags) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit('min')}, {self.get_limit('max')})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit('min')})"
        if 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the unsigned typol entry of the same width.
        return '.type = YNL_PT_U' + c_upper(self.type[1:]) + ', '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.mnl_type())

    def _attr_get(self, ri, var):
        line = f"{var}->{self.c_name} = mnl_attr_get_{self.mnl_type()}(attr);"
        return line, None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
397
398
class TypeFlag(Type):
    """Zero-length attribute whose only information is its presence."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter takes no value; it only sets the presence bit.
        return []

    def _setter_lines(self, ri, member, presence):
        return []

    def attr_put(self, ri, var):
        line = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        # The generic parse branch already records presence; nothing to copy.
        return [], None, None
414
415
class TypeString(Type):
    """String attribute stored as a heap copy plus a ``_present.<name>_len``."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.checks['exact-len']})"
        parts = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            parts.append(', .len = ' + str(self.get_limit('max-len')))
        parts.append(', }')
        return ''.join(parts)

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the policy from NLA_NUL_STRING to NLA_STRING.
        policy = 'NLA_STRING' if self.checks.get('unterminated-ok', False) else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [f"{var}->_present.{self.c_name}_len = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, mnl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
466
467
class TypeBinary(Type):
    """Opaque byte blob stored as a heap copy plus a ``_present.<name>_len``."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.checks['exact-len']})"
        # Only "no checks" or "just min-len" are supported here.
        if not self.checks:
            body = '.type = NLA_BINARY'
        elif len(self.checks) == 1 and 'min-len' in self.checks:
            body = '.len = ' + str(self.get_limit('min-len'))
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        return '{ ' + body + ', }'

    def attr_put(self, ri, var):
        call = (f"mnl_attr_put(nlh, {self.enum_name}, "
                f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        return ([f"{var}->_present.{self.c_name}_len = len;",
                 f"{dst} = malloc(len);",
                 f"memcpy({dst}, mnl_attr_get_payload(attr), len);"],
                ['len = mnl_attr_get_payload_len(attr);'],
                ['unsigned int len;'])

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = len;",
                f"{member} = malloc({length});",
                f'memcpy({member}, {self.c_name}, {length});']
512
513
class TypeBitfield32(Type):
    """'bitfield32' attribute, held in a ``struct nla_bitfield32``."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        # The mask of valid bits comes from the referenced enum.
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        call = f"mnl_attr_put(nlh, {self.enum_name}, sizeof(struct nla_bitfield32), &{var}->{self.c_name})"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        copy = f"memcpy(&{var}->{self.c_name}, mnl_attr_get_payload(attr), sizeof(struct nla_bitfield32));"
        return copy, None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
537
538
class TypeNest(Type):
    """
    Attribute carrying a nested attribute set.

    The nest is rendered as an embedded struct, or as a pointer when it is
    recursive outside an operation.
    """

    def is_recursive(self):
        return self.family.pure_nested_structs[self.nested_attrs].recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def free(self, ri, var, ref):
        target = f'{var}->{ref}{self.c_name}'
        if self.is_recursive_for_op(ri):
            # Stored by pointer: may be NULL.
            ri.cw.p(f'if ({target})')
            ri.cw.p(f'{self.nested_render_name}_free({target});')
        else:
            ri.cw.p(f'{self.nested_render_name}_free(&{target});')

    def attr_put(self, ri, var):
        addr = '' if self.is_recursive_for_op(ri) else '&'
        call = (f"{self.nested_render_name}_put(nlh, "
                f"{self.enum_name}, {addr}{var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return MNL_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # No setter for the nest itself; emit setters for its (non-recursive)
        # members with this nest appended to the member path.
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
578
579
class TypeMultiAttr(Type):
    """
    Attribute that may repeat ('multi-attr').

    Values are kept as an array plus an ``n_<name>`` count. Policy and typol
    rendering are delegated to *base_type*, the single-instance type object.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def mnl_type(self):
        return self._mnl_type()

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        if self.attr['type'] in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + self.attr['type']
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def free(self, ri, var, ref):
        # Reading self.attr['type'] up front (as before) means it is present below.
        sub_type = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"
        if sub_type in scalars:
            ri.cw.p(f"free({array});")
        elif sub_type == 'nest':
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{array}[i]);')
            ri.cw.p(f"free({array});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {sub_type} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        sub_type = self.attr['type']
        loop = f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)"
        if sub_type in scalars:
            put_type = self.mnl_type()
            ri.cw.p(loop)
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif sub_type == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {sub_type} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # Multi-attrs track a count, not a presence bit: rewrite the
        # "..._present.<name>" path produced by setter() into "...n_<name>".
        count = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{count} = n_{self.c_name};"]
644
645
class TypeArrayNest(Type):
    """'array-nest' attribute: a nest whose sub-attributes are counted as array entries."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # A missing 'sub-type' means nested structs.
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            return ('__' if ri.ku_space == 'user' else '') + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the nest and count its entries.
        get_lines = [f'attr_{self.c_name} = attr;',
                     'mnl_attr_for_each_nested(attr2, attr)',
                     f'\t{var}->n_{self.c_name}++;']
        return get_lines, None, ['const struct nlattr *attr2;']
671
672
class TypeNestTypeValue(Type):
    """
    'nest-type-value' attribute.

    Listed 'type-value' levels are descended one nest at a time. The type of
    each level is captured into a __u32 local and passed to the nested
    parser.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        cursor = 'attr'
        extra_args = ''
        if 'type-value' in self.attr:
            tv_names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(tv_names)};')
            local_vars.append(f'__u32 {", ".join(tv_names)};')
            for level in tv_names:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                cursor = 'attr_' + level
            extra_args = f", {', '.join(tv_names)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
701
702
class Struct:
    """
    C struct generated for an attribute set.

    With *type_list* the struct holds only the listed attributes. Without it,
    it covers the whole set and is treated as a nested struct.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use a list (not a set) for the default so comparisons with empty
        # sets in set_inherited() are caught.
        self._inherited = inherited if inherited is not None else []
        self.inherited = []

        self.nested = type_list is None
        # NOTE(review): compares the raw family name to the c_lower()'d set
        # name, whereas Type.__init__ compares raw names. The two disagree for
        # family names containing '-'; confirm this is intended.
        same_name = family.name == c_lower(space_name)
        self.render_name = c_lower(family.name if same_name else family.name + '-' + space_name)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = dict(self.attr_list)

        # Attribute with the highest value. Ties go to the later one, and
        # negative values are never picked.
        self.attr_max_val = None
        top = 0
        for _, attr in self.attr_list:
            if attr.value >= top:
                top = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members; they must match those given at construction."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
758
759
class EnumEntry(SpecEnumEntry):
    """Enum entry extended with its C name and a 'value_change' marker."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change is True when the value does not just follow the
        # previous entry (or 0 for the first entry), and always for flags.
        implicit = prev.value + 1 if prev else 0
        self.value_change = self.value != implicit or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
779
class EnumSet(SpecEnumSet):
    """Enum/flags definition with the C names the generator uses for it."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        # An explicit empty 'enum-name' means "no named C enum".
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (lowest, highest) entry value; the values must be contiguous."""
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
813
814
class AttrSet(SpecAttrSet):
    """Attribute set with the C naming (enum prefix, max/cnt names) used by the generator."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the naming of the set it is carved out of.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        # The set named like the family gets an empty C name.
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching elem['type'], wrapped for multi-attr."""
        kind = elem['type']
        if kind in scalars:
            cls = TypeScalar
        else:
            cls = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(kind)
            if cls is None:
                raise Exception(f"No typed class for type {elem['type']}")

        t = cls(self.family, self, elem, value)
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)
        return t
873
874
class Operation(SpecOperation):
    """Spec operation with its C names and notification flag."""
    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.name + '_' + self.name)

        # True only when both the 'do' and the 'dump' mode carry a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Set later by mark_has_ntf() if another op names this one in 'notify'
        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the base operation, then compute its C enum name."""
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that this operation has an associated notification."""
        self.has_ntf = True
900
901
class Family(SpecFamily):
    """SpecFamily plus the derived state the C code generator works from.

    Resolution derives C naming (c_name, op prefixes), hook lists, and the
    root / pure-nested struct maps used to render parsers and policies.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # (declared for readers, then deleted so they stay unset until
        # resolution assigns them)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.name}.h"
        # "linux/<name>.h" -> "<name>"; anything else falls back to the family name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.name

    def resolve(self):
        """Resolve the base spec, then derive the code generator's state.

        Raises:
            Exception: if the protocol is not one of the genetlink flavours.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode]: 'set' de-duplicates names, 'list' keeps order
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct (request/reply flags live on the Struct)
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every op that some message names as its 'notify' source."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per attribute set used by an op, its request/reply attrs."""
        for op in self.msgs.values():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct follows its children.

        Recursive children are skipped (they become pointers). The pass is
        bounded to len**2 rounds.
        """
        # Try to reorder according to dependencies
        pns_key_list = collections.deque(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if not pns_key_list:
                break
            name = pns_key_list.popleft()
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Find every set nested under a root set and build its Struct.

        Also records inherited members, propagates the request/reply usage
        and recursion flags, and orders the structs by dependency.

        Raises:
            Exception: if a set is used both as a root and as a nest.
        """
        # Breadth-first walk starting from the root sets
        attr_set_queue = collections.deque(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while attr_set_queue:
            a_set = attr_set_queue.popleft()
            for _, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark each attribute as used in requests and/or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.request = True
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.reply = True

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.request = True
                if attr in rs_members['reply']:
                    spec.reply = True

    def _load_global_policy(self):
        """Build the single policy shared by all ops (kernel-policy: global).

        The policy lists every request attribute of any op, in attribute
        set order.

        Raises:
            Exception: if ops use different attribute sets, or none uses one.
        """
        global_set = set()
        attr_set_name = None
        for op in self.ops.values():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        if attr_set_name is None:
            # Otherwise the lookup below fails with an opaque KeyError: None
            raise Exception('For a global policy at least one op must use an attribute set')

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per op mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1163
1164
class RenderInfo:
    """Context for rendering one op/mode, or an op-less attribute set.

    When op is falsy, the type is named after attr_set and no
    request/reply structs are built.
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        # Guard on op like every other op-dependent branch here: without an
        # op there is no do/dump pair to compare ('dump' in None would raise)
        if op and op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            # Flag an attr set sharing its name with a family constant
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        if op_mode == 'notify':
            # Notifications build their structs from the 'do' spec
            op_mode = 'do'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1213
1214
class CodeWriter:
    """Indentation-tracking writer for generated C code.

    Output goes to stdout or, when out_file is given, to a temporary file
    that close_out_file() copies over out_file.
    """
    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        # With overwrite=False an existing out_file with identical contents
        # is left untouched
        self._overwrite = overwrite

        self._nl = False            # blank line requested before the next line
        self._block_end = False     # '}' held back in case an 'else' follows
        self._silent_block = False  # previous line was a brace-less if/while/for
        self._ind = 0               # current indentation level, in tabs
        self._ifdef_block = None    # CONFIG_ option of the open #ifdef, if any
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def _flush_block_end(self):
        """Write out a closing bracket still held back by block_end()."""
        if self._block_end:
            self._block_end = False
            self._out.write('\t' * self._ind + '}\n')

    def close_out_file(self):
        """Finish output and, for file output, publish it to out_file.

        Any pending '}' is written first. The destination is not rewritten
        when overwrite is disabled and its contents already match. In every
        case the temporary file is closed and output reverts to stdout.
        """
        self._flush_block_end()
        if self._out == os.sys.stdout:
            return
        # Avoid modifying the file if contents didn't change
        self._out.flush()
        unchanged = (not self._overwrite and os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        # Closing a NamedTemporaryFile also deletes it
        self._out.close()
        self._out = os.sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation (plus add_ind)."""
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                # Join the delayed bracket with the else: "} else ..."
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            # Lines ending in ':' (e.g. goto labels) are outdented one level
            ind -= 1
        if self._silent_block:
            # Body of a brace-less if/while/for gets one extra level
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            # Preprocessor directives always start at column 0
            ind = 0
        if add_ind:
            ind += add_ind
        if line:
            self._out.write('\t' * ind + line + '\n')
        else:
            # No trailing whitespace on empty lines
            self._out.write('\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write "<line> {" and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Outdent one level and close the block with '}' plus optional line."""
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write doc as ' *'-prefixed comment lines wrapped below 79 columns."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping arguments past 80 columns."""
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        local_vars may be a single string; the caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write (name, value) pairs as tab-aligned #define lines."""
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write (name, value) pairs as tab-aligned designated initializers."""
        # default=0 lets an empty member list write nothing instead of raising
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the open #ifdef to CONFIG_<config> (None closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1410
1411
# Attribute types handled by TypeScalar
scalars = set('u8 u16 u32 u64 s32 s64 uint sint'.split())

# Type-name suffix for each message direction ('' means no direction)
direction_to_suffix = {'reply': '_rsp', 'request': '_req', '': ''}

# Suffix of the reply wrapper type for each op mode
op_mode_to_wrapper = dict(do='', dump='_list', notify='_ntf', event='')

# C keywords; a set whose C name collides with one gets a trailing '_'
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1463
1464
def rdir(direction):
    """Return the opposite message direction; other values pass through."""
    if direction in ('reply', 'request'):
        return 'request' if direction == 'reply' else 'reply'
    return direction
1471
1472
def op_prefix(ri, direction, deref=False):
    """Return the C type name (without 'struct ') for ri's message.

    Args:
        ri: RenderInfo being rendered.
        direction: 'request', 'reply' or ''.
        deref: for dumps, name the per-element type rather than the list.
    """
    suffix = f"_{ri.type_name}"
    if not ri.op_mode or ri.op_mode == 'do':
        suffix += direction_to_suffix[direction]
    elif direction == 'request':
        suffix += '_req_dump'
    elif not ri.type_consistent:
        # do and dump replies differ, so the dump gets its own types
        suffix += '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        suffix += direction_to_suffix[direction]
    else:
        suffix += op_mode_to_wrapper[ri.op_mode]
    return f"{ri.family.c_name}{suffix}"
1492
1493
def type_name(ri, direction, deref=False):
    """Return the full C struct type ("struct ...") for ri's message."""
    return 'struct ' + op_prefix(ri, direction, deref=deref)
1496
1497
def print_prototype(ri, direction, terminate=True, doc=None):
    """Print the prototype of the user-space helper for ri's op and mode.

    The helper takes the socket plus a request pointer when the mode has a
    request. It returns a reply pointer when the mode has a reply, and int
    otherwise.
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1514
1515
def print_req_prototype(ri):
    """Print the request helper's prototype, preceded by the op's doc."""
    # An op without a 'doc' key gets a prototype without a comment instead
    # of failing with KeyError
    doc = ri.op['doc'] if 'doc' in ri.op else None
    print_prototype(ri, "request", doc=doc)
1518
1519
def print_dump_prototype(ri):
    """Print the dump helper's prototype (no doc comment)."""
    print_prototype(ri, direction="request")
1522
1523
def put_typol_fwd(cw, struct):
    """Emit the extern declaration of struct's ynl_policy_nest object."""
    nest_name = struct.render_name + '_nest'
    cw.p('extern struct ynl_policy_nest ' + nest_name + ';')
1526
1527
def put_typol(cw, struct):
    """Emit struct's per-attribute policy table and its nest descriptor."""
    type_max = struct.attr_set.max_name
    prefix = struct.render_name

    # Table sized to hold an entry for every attribute up to the max
    cw.block_start(line=f'struct ynl_policy_attr {prefix}_policy[{type_max} + 1] =')
    for _name, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    # Descriptor pointing at the table above
    cw.block_start(line=f'struct ynl_policy_nest {prefix}_nest =')
    for field in (f'.max_attr = {type_max},', f'.table = {prefix}_policy,'):
        cw.p(field)
    cw.block_end(line=';')
    cw.nl()
1543
1544
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit <render_name>_str(), a bounds-checked lookup into map_name.

    The argument is typed as enum.user_type when an enum is given, int
    otherwise. For 'flags' enums the value is first turned into the index
    of its lowest set bit.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    for stmt in (f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))',
                 'return NULL;',
                 f'return {map_name}[{arg_name}];'):
        cw.p(stmt)
    cw.block_end()
    cw.nl()
1558
1559
def put_op_name_fwd(family, cw):
    """Emit the declaration of <family>_op_str()."""
    fn_name = family.c_name + '_op_str'
    cw.write_func_prot('const char *', fn_name, ['int op'], suffix=';')
1562
1563
def put_op_name(family, cw):
    """Emit the reply-value -> op name table and its lookup function."""
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families: only the op that
        # rsp_by_value maps the value to gets the slot.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        slot = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{slot}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1583
1584
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the declaration of <enum>_str() (family is unused)."""
    fn_name = enum.render_name + '_str'
    cw.write_func_prot('const char *', fn_name, [f'{enum.user_type} value'], suffix=';')
1588
1589
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string table for enum and its lookup function."""
    map_name = enum.render_name + '_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in (f'[{e.value}] = "{e.name}",' for e in enum.entries.values()):
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1599
1600
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of <struct>_put(nlh, attr_type, obj)."""
    obj_arg = struct.ptr_name + 'obj'
    ri.cw.write_func_prot('int', struct.render_name + '_put',
                          ['struct nlmsghdr *nlh', 'unsigned int attr_type', obj_arg],
                          suffix=suffix)
1608
1609
def put_req_nested(ri, struct):
    """Emit <struct>_put(), which writes obj's members into a nest attr."""
    cw = ri.cw
    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1626
1627
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parse callback that fills ``dst`` from attributes.

    First emits a loop over the message (or nest) attributes that hands
    each one to its member's attr_get() renderer. Then, for array-nest and
    multi-attr members, allocates dst arrays sized by the n_<member> local
    counters and fills them in a second pass. The counters are presumably
    bumped by the attr_get() code in the first pass.

    Args:
        ri: RenderInfo of the op being rendered.
        struct: Struct being parsed into.
        init_lines: C statements emitted after the locals; extended in place.
        local_vars: C local declarations; extended in place.

    Raises:
        Exception: for a multi-attr member that is neither nested nor scalar.
    """
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "mnl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        ri.cw.p('hdr = mnl_nlmsg_get_payload_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        # Size the array by the element counter, as the multi-attr case does;
        # the bare member name is not a C local here.
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        # The attribute type doubles as the inherited 'idx' argument
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{aspec.mnl_type()}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1736
1737
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the C prototype of the parser for a nested attribute set.

    The generated ``<render_name>_parse()`` takes the parse context and the
    nest attribute, followed by one ``__u32`` argument per name listed in
    ``struct.inherited``.  *suffix* is written after the prototype: ';' for a
    declaration, '' when a function body follows.
    """
    inherited_params = [f'__u32 {name}' for name in struct.inherited]
    params = ['struct ynl_parse_arg *yarg',
              'const struct nlattr *nested',
              *inherited_params]
    ri.cw.write_func_prot('int', struct.render_name + '_parse', params,
                          suffix=suffix)
1746
1747
def parse_rsp_nested(ri, struct):
    """Emit the full C parser function for a nested attribute set.

    Writes the prototype (without terminator) and then delegates the body to
    _multi_parse(), seeding it with the attribute cursor and a ``dst``
    pointer taken from ``yarg->data``.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    decls = ['const struct nlattr *attr;']
    decls.append(f'{struct.ptr_name}dst = yarg->data;')

    _multi_parse(ri, struct, [], decls)
1756
1757
def parse_rsp_msg(ri, deref=False):
    """Emit the top-level netlink message parse callback for the op's reply.

    Nothing is emitted when the current op mode has no 'reply' section,
    unless the mode is 'event' (events are parsed regardless).  A reply
    struct without members gets a trivial body returning MNL_CB_OK;
    otherwise the body is generated by _multi_parse().
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    params = ['const struct nlmsghdr *nlh', 'void *data']
    decls = [f'{type_name(ri, "reply", deref=deref)} *dst;',
             'struct ynl_parse_arg *yarg = data;',
             'const struct nlattr *attr;']
    setup = ['dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', params)

    if not ri.struct["reply"].member_list():
        # Empty reply: accept the message without looking at attributes.
        ri.cw.block_start()
        ri.cw.p('return MNL_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, ri.struct["reply"], setup, decls)
1780
1781
def print_req(ri):
    """Emit the C body of a blocking "do" request function for the op.

    The generated code:
    - builds the message, including the fixed header if any, followed by
      every request attribute;
    - runs it through ynl_exec().

    When the op defines a reply, the function allocates a response,
    registers the reply parser and returns the response (freeing it on
    error).  Otherwise it returns 0 on success and -1 on error.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]

    lvars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
             'struct nlmsghdr *nlh;',
             'int err;']
    if has_reply:
        ret_ok, ret_err = 'rsp', 'NULL'
        lvars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    else:
        ret_ok, ret_err = '0', '-1'
    if ri.fixed_hdr:
        lvars.extend(['size_t hdr_len;', 'void *hdr;'])

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(lvars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = mnl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(stmt)
        ri.cw.nl()

    for _, member in ri.struct["request"].member_list():
        member.attr_put(ri, "req")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Ops without their own value expect the reply under rsp_value.
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        ri.cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto err_free;' if has_reply else 'return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('err_free:')
        ri.cw.p(call_free(ri, rdir(direction), 'rsp'))
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
1846
1847
def print_dump(ri):
    """Emit the C body of a dump request function for the op.

    The generated code:
    - fills a ``struct ynl_dump_state`` (reply size, parse callback, expected
      reply command and reply policy);
    - starts a dump message, adding the optional fixed header and any
      request attributes;
    - runs ynl_exec_dump() and returns ``yds.first``, freeing that list on
      error.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()

    lvars = ['struct ynl_dump_state yds = {};',
             'struct nlmsghdr *nlh;',
             'int err;']
    if ri.fixed_hdr:
        lvars.extend(('size_t hdr_len;', 'void *hdr;'))
    ri.cw.write_func_lvar(lvars)

    # Ops without their own value expect the reply under rsp_value.
    if ri.op.value is not None:
        rsp_cmd = ri.op.enum_name
    else:
        rsp_cmd = ri.op.rsp_value

    ri.cw.p('yds.ys = ys;')
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    ri.cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    ri.cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = mnl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(stmt)
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, member in ri.struct["request"].member_list():
            member.attr_put(ri, "req")
    ri.cw.nl()

    for stmt in ('err = ynl_exec_dump(ys, nlh, &yds);',
                 'if (err < 0)',
                 'goto free_list;'):
        ri.cw.p(stmt)
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1897
1898
def call_free(ri, direction, var):
    """Return the C statement that releases *var* via the op's _free() helper."""
    free_fn = f"{op_prefix(ri, direction)}_free"
    return f"{free_fn}({var});"
1901
1902
def free_arg_name(direction):
    """Return the parameter name used in a generated _free() function.

    A falsy *direction* (nested, direction-less types) yields 'obj';
    otherwise the direction's suffix from direction_to_suffix is used with
    its first character dropped (presumably a leading '_').
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1907
1908
def print_alloc_wrapper(ri, direction):
    """Emit ``static inline struct <name> *<name>_alloc(void)``.

    The generated helper returns a zero-initialised object allocated with
    calloc(), where <name> is the op prefix for *direction*.
    """
    name = op_prefix(ri, direction)
    # Plain string: the argument list has nothing to interpolate.
    ri.cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof(struct {name}));')
    ri.cw.block_end()
1915
1916
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of the ``<prefix>_free()`` function for *direction*.

    The argument's struct tag gets a trailing '_' when
    ``ri.type_name_conflict`` is set.  The argument name comes from
    free_arg_name().  *suffix* follows the prototype (';' or '').
    """
    fn_prefix = op_prefix(ri, direction)
    struct_tag = fn_prefix + '_' if ri.type_name_conflict else fn_prefix
    param = f"struct {struct_tag} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{fn_prefix}_free", [param], suffix=suffix)
1924
1925
def _print_type(ri, direction, struct):
    """Emit the C struct definition for *struct*.

    The struct tag is the family's C name plus a suffix. The suffix is built
    from ri.type_name and the direction suffix; a trailing '_' is appended
    for direction-less types with a name conflict, then '_dump' in dump mode.

    Layout of the emitted struct:
    - the optional fixed header;
    - an anonymous ``struct { ... } _present;`` holding every member's
      presence fields (length and bit variants, as reported by
      presence_member());
    - one ``__u32`` per inherited value;
    - the members themselves.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'

    if ri.op_mode == 'dump':
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.fixed_hdr:
        ri.cw.p(ri.fixed_hdr + ' _hdr;')
        ri.cw.nl()

    # Open the _present sub-struct lazily, only once a presence field exists.
    meta_started = False
    for _, attr in struct.member_list():
        for type_filter in ['len', 'bit']:
            line = attr.presence_member(ri.ku_space, type_filter)
            if line:
                if not meta_started:
                    ri.cw.block_start(line="struct")
                    meta_started = True
                ri.cw.p(line)
    if meta_started:
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
1961
1962
def print_type(ri, direction):
    """Emit the C struct for the op's *direction* ('request' or 'reply')."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1965
1966
def print_type_full(ri, struct):
    """Emit *struct* as a standalone C type with an empty direction."""
    _print_type(ri, direction="", struct=struct)
1969
1970
def print_type_helpers(ri, direction, deref=False):
    """Emit the _free() prototype and, for user-space requests, member setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    emit_setters = ri.ku_space == 'user' and direction == 'request'
    if emit_setters:
        for _member_name, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1979
1980
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and type helpers for a non-empty request struct."""
    if not ri.struct["request"].attr_list:
        # Nothing to allocate or set for attribute-less requests.
        return
    print_alloc_wrapper(ri, "request")
    print_type_helpers(ri, "request")
1986
1987
def print_rsp_type_helpers(ri):
    """Emit the reply type helpers when the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1992
1993
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of ``<op>_{rsp,req}_parse(tb, req)``.

    'reply' selects the ``_rsp`` variant; any other direction selects
    ``_req``.  With *terminate* the prototype ends in ';'.
    """
    tag = "_rsp" if direction == "reply" else "_req"
    type_base = f"{ri.op.render_name}{tag}"
    params = ['const struct nlattr **tb', f"struct {type_base} *req"]
    ri.cw.write_func_prot('void', f"{type_base}_parse", params,
                          suffix=';' if terminate else '')
2002
2003
def print_req_type(ri):
    """Emit the request struct, unless the request has no attributes."""
    if ri.struct["request"].attr_list:
        print_type(ri, "request")
2008
2009
def print_req_free(ri):
    """Emit the request _free() function when the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2014
2015
def print_rsp_type(ri):
    """Emit the reply struct for do/dump ops with a reply, and for events."""
    has_reply = ri.op_mode in ('do', 'dump') and 'reply' in ri.op[ri.op_mode]
    if not (has_reply or ri.op_mode == 'event'):
        return
    print_type(ri, 'reply')
2024
2025
def print_wrapped_type(ri):
    """Emit the wrapper struct around a reply object, plus its _free() prototype.

    Dump replies get a ``next`` pointer of the same type.  Notifications and
    events get family/cmd fields, a ``next`` link of type
    ``struct ynl_ntf_base_type *`` and a ``free`` callback.  In all cases the
    reply object itself is embedded as ``obj`` with 8-byte alignment.
    """
    wrapper = type_name(ri, 'reply')
    ri.cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        ri.cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for field in ('__u16 family;',
                      '__u8 cmd;',
                      'struct ynl_ntf_base_type *next;',
                      f"void (*free)({wrapper} *ntf);"):
            ri.cw.p(field)
    inner = type_name(ri, 'reply', deref=True)
    ri.cw.p(f"{inner} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()
    print_free_prototype(ri, 'reply')
    ri.cw.nl()
2040
2041
def _free_type_members_iter(ri, struct):
    """Declare ``unsigned int i;`` if any member's free code needs a loop index."""
    # any() stops at the first member that needs the index, like the loop it replaces.
    if any(member.free_needs_iter() for _, member in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
2048
2049
def _free_type_members(ri, var, struct, ref=''):
    """Emit the free code of every member of *struct*, accessed as ``var->ref<member>``."""
    members = [member for _, member in struct.member_list()]
    for member in members:
        member.free(ri, var, ref)
2053
2054
def _free_type(ri, direction, struct):
    """Emit a complete _free() function for *struct*.

    The body frees every member.  Only when *direction* is non-empty does it
    also free() the object itself; direction-less types only release their
    members.
    """
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if not direction:
        ri.cw.block_end()
        ri.cw.nl()
        return
    ri.cw.p(f'free({arg});')
    ri.cw.block_end()
    ri.cw.nl()
2066
2067
def free_rsp_nested_prototype(ri):
    """Emit the _free() prototype for a nested (direction-less) type."""
    print_free_prototype(ri, "")
2070
2071
def free_rsp_nested(ri, struct):
    """Emit the _free() function for a nested (direction-less) type."""
    _free_type(ri, direction="", struct=struct)
2074
2075
def print_rsp_free(ri):
    """Emit the reply _free() function when the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2080
2081
def print_dump_type_free(ri):
    """Emit the _free() function for a dump reply list.

    The generated loop walks the list via ``next`` until YNL_LIST_END,
    releasing each element's members (reached through ``obj.``) and then the
    element itself.
    """
    sub_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{sub_type} *next = rsp;")
    ri.cw.nl()
    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    # Advance before freeing so the link is read from a still-valid element.
    ri.cw.p('rsp = next;')
    ri.cw.p('next = rsp->next;')
    ri.cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.block_end()
    ri.cw.nl()
2100
2101
def print_ntf_type_free(ri):
    """Emit the _free() function for a notification object.

    Releases the members reached through ``rsp->obj.`` and then the wrapper.
    """
    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, ri.struct['reply'])
    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2110
2111
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the head of an ``nla_policy`` array for *struct*.

    With *terminate* an ``extern`` declaration ending in ';' is written.
    This is skipped entirely for op-specific (*ri* given) policies that
    policy_should_be_static() marks static.  Without *terminate* the
    definition opener ``... = {`` is written, prefixed with ``static`` for
    such policies.

    The array is named after the op (with the op mode appended for
    dual-policy ops) when *ri* is given, otherwise after the struct.
    """
    is_static = policy_should_be_static(struct.family) if ri else False

    if terminate:
        if is_static:
            # A static policy has no external declaration to emit.
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    elif ri.op.dual_policy:
        name = ri.op.render_name + '_' + ri.op_mode
    else:
        name = ri.op.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2134
2135
def print_req_policy(cw, struct, ri=None):
    """Emit a complete ``nla_policy`` array definition for *struct*.

    For op policies the definition is wrapped in the op's 'config-cond'
    guard via cw.ifdef_block().  ifdef_block(None) is always called at the
    end, presumably to close any guard that is still open.
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for member in (m for _, m in struct.member_list()):
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2145
2146
def kernel_can_gen_family_struct(family):
    """Return True when the ``struct genl_family`` can be generated (genetlink only)."""
    is_genetlink = family.proto == 'genetlink'
    return is_genetlink
2149
2150
def policy_should_be_static(family):
    """Return whether generated kernel policies should be file-local (static).

    True for split-policy families and for families whose
    ``struct genl_family`` is generated (so the policies are only referenced
    from the generated source).
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2153
2154
def print_kernel_policy_ranges(family, cw):
    """Emit ``netlink_range_validation[_signed]`` structs for full-range checks.

    Only request attributes of non-subset attribute sets that carry a
    'full-range' check are considered.  Unsigned types (type name starting
    with 'u') use the unsigned struct and ULL literals; everything else uses
    the signed variant and LL.  A single header comment precedes the first
    emitted range.
    """
    header_done = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_done:
                cw.p('/* Integer value ranges */')
                header_done = True

            if attr.type[0] == 'u':
                sign, lit = '', 'ULL'
            else:
                sign, lit = '_signed', 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            bounds = [(bound, str(attr.get_limit(bound)) + lit)
                      for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(bounds)
            cw.block_end(line=';')
            cw.nl()
2182
2183
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops-table declaration/opener plus handler prototypes.

    Without *terminate* this writes only the opener of the ops array
    definition (``... _nl_ops[] =``).  With *terminate* it writes:
    - an ``extern`` declaration of the array, only when the family struct is
      not generated (the table is then exported);
    - prototypes for all pre/post hooks;
    - prototypes for every non-async op's doit/dumpit handler.

    The array element type depends on family.kernel_policy.  For exported
    split tables the length counts one entry per do/dump handler.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if not terminate:
            cw.block_start(line=decl + ' =')
            return
        cw.p(f"extern {decl};")

    cw.nl()
    doit_hook_args = ['const struct genl_split_ops *ops',
                      'struct sk_buff *skb', 'struct genl_info *info']
    dump_hook_args = ['struct netlink_callback *cb']
    hook_table = (('pre', 'do', 'int', doit_hook_args),
                  ('post', 'do', 'void', doit_hook_args),
                  ('pre', 'dump', 'int', dump_hook_args),
                  ('post', 'dump', 'int', dump_hook_args))
    for when, mode, ret, args in hook_table:
        for hook in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret, c_lower(hook), list(args), suffix=';')

    cw.nl()

    handler_args = (('do', ['struct sk_buff *skb', 'struct genl_info *info']),
                    ('dump', ['struct sk_buff *skb', 'struct netlink_callback *cb']))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue
        for mode, args in handler_args:
            if mode in op:
                handler = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
                cw.write_func_prot('int', handler, list(args), suffix=';')
    cw.nl()
2249
2250
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side declarations for the kernel ops table."""
    return print_kernel_op_table_fwd(family, cw, terminate=True)
2253
2254
def print_kernel_op_table(family, cw):
    """Emit the kernel ops table definition (the initialiser of <family>_nl_ops).

    'global' and 'per-op' families get one entry per non-async op, holding
    cmd, optional validate flags, the doit/dumpit handlers, the policy
    (per-op only) and flags.  'split' families get one entry per do/dump
    handler of each non-async op.  Those entries carry mode-filtered
    validate flags, pre/post hooks, the policy (when the mode has a request)
    and flags that always include the GENL_CMD_CAP_<MODE> bit.  Each entry
    is wrapped by cw.ifdef_block() using the op's 'config-cond'.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): the per-op policy is always derived from the
                # 'do' request; an op without do/request would raise KeyError
                # here -- confirm specs guarantee it.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Map each mode's pre/post hook onto the genl_split_ops member name.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the dont-validate flags relevant to this mode:
                    # 'dump'/'dump-strict' are dropped for do entries and
                    # 'strict' is dropped for dump entries.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                # Member order here is the order in the emitted initialiser.
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Reset the config guard (presumably closes any #ifdef left open above).
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2333
2334
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the enum of multicast group ids (``<FAMILY>_NLGRP_<NAME>``), if any."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']}") + ',')
    cw.block_end(';')
    cw.nl()
2345
2346
def print_kernel_mcgrp_src(family, cw):
    """Emit the ``genl_multicast_group`` array, indexed by group enum id."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2358
2359
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern declaration of the generated ``struct genl_family``."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
        cw.nl()
2366
2367
def print_kernel_family_struct_src(family, cw):
    """Emit the definition of the family's ``struct genl_family``.

    Fills name/version, netnsok/parallel_ops/module, the ops table (classic
    ``ops`` for per-op policies, ``split_ops`` for split policies) and the
    multicast groups when the family has any.  Nothing is emitted for
    families whose struct cannot be generated.
    """
    if not kernel_can_gen_family_struct(family):
        return

    # Use c_name, matching the extern declaration emitted for the header
    # (print_kernel_family_struct_hdr); the raw spec name may contain '-',
    # which is not valid in a C identifier.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    cw.block_end(';')
2388
2389
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a C enum block for *obj* in the uAPI header.

    If *obj* has the *enum_name* key, its value names the enum; an empty
    value yields an anonymous enum.  Otherwise, if *ckey* is set and present
    in *obj*, the name is ``<family c_name>_<obj[ckey]>``.  In all other
    cases the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2398
2399
def render_uapi(family, cw):
    """Render the complete uAPI header for *family* through *cw*.

    Emits, in order:
    - the include guard;
    - the family name/version defines;
    - every spec 'definitions' entry: enums/flags with optional kdoc, and
      consts as #defines;
    - one enum per attribute set that is not a subset of another;
    - the command enum, plus a separate notification enum when the
      operations carry 'async-prefix';
    - the multicast group name defines.

    With 'max-by-define' the MAX values are emitted as #defines instead of
    enum entries.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are batched and flushed as a single
    # group of #defines whenever a non-const definition follows.
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The macro-name override is looked up on the constant itself
            # (as is done for multicast groups below). A family-level lookup
            # would give every constant the same macro name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # Only write explicit values when they break the implicit sequence.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            # Only notifications and events belong in the async enum.
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2543
2544
def _render_user_ntf_entry(ri, op):
    """
    Emit one designated-initializer entry of the ynl_ntf_info table.

    The entry is indexed by the op's enum name and fills in the allocation
    size, the parse callback, the attribute policy and the free helper.
    """
    writer = ri.cw
    writer.block_start(line=f"[{op.enum_name}] = ")
    fields = (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    )
    for field in fields:
        writer.p(field)
    writer.block_end(line=',')
2552
2553
def render_user_family(family, cw, prototype):
    """
    Emit the user-space `struct ynl_family` descriptor for the family.

    When prototype is true only the extern declaration is written (used by
    the generated header).  Otherwise the notification info table (when the
    family has notifications) and the full family definition are written.

    Raises Exception if a notification entry has neither 'notify' nor
    'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # The table name is a C identifier, so derive it from the C-ified family
    # name (exactly like the family symbol above) rather than the raw spec
    # name, which may contain dashes.
    ntf_info = f"{family.c_name}_ntf_info"

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_info}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # 'notify' names another op; the entry is rendered from that op.
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_info},")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({ntf_info}),")
    cw.block_end(line=';')
2589
2590
def family_contains_bitfield32(family):
    """
    Report whether any attribute in the family is of type bitfield32.

    Attribute sets with a truthy `subset_of` are not inspected (presumably
    their attributes are covered by the set they are a subset of).
    Returns True on the first match, False otherwise.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
2599
2600
def find_kernel_root(full_path):
    """
    Locate the kernel source tree that contains full_path.

    Walks up from full_path, one directory at a time, until a directory
    containing a MAINTAINERS file is found.

    Returns a tuple (kernel_root, sub_path), where sub_path is full_path
    relative to kernel_root.

    Raises FileNotFoundError when the walk reaches the top of the path (the
    filesystem root for absolute paths, the current directory for relative
    ones) without finding MAINTAINERS.  Previously this case looped forever.
    """
    start_path = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Drop the trailing separator left by the initial join with ''.
            return parent, sub_path[:-1]
        # dirname() reaches a fixed point at the top ('/' or ''), so there
        # are no further ancestors left to check.
        if parent == full_path:
            raise FileNotFoundError(
                f"No MAINTAINERS file found in any parent directory of '{start_path}'")
        full_path = parent
2609
2610
def _main_includes(args, parsed, cw):
    """Emit the #include lines (and user-header forward declarations)."""
    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # The generated source includes its companion header:
                # foo.c -> "foo.h".
                hdr_base = os.path.splitext(os.path.basename(args.out_file))[0]
                cw.p(f'#include "{hdr_base}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()


def _main_kernel_header(parsed, cw, mode):
    """Emit the kernel header: policy forward declarations, op table, family."""
    if any(struct.request for struct in parsed.pure_nested_structs.values()):
        cw.p('/* Common nested types */')
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy_fwd(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy_fwd(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for _, op in parsed.ops.items():
            if 'do' in op and 'event' not in op:
                ri = RenderInfo(cw, parsed, mode, op, "do")
                print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                cw.nl()

    print_kernel_op_table_hdr(parsed, cw)
    print_kernel_mcgrp_hdr(parsed, cw)
    print_kernel_family_struct_hdr(parsed, cw)


def _main_kernel_source(parsed, cw, mode):
    """Emit the kernel source: policies, op table, mcast groups, family."""
    print_kernel_policy_ranges(parsed, cw)

    if any(struct.request for struct in parsed.pure_nested_structs.values()):
        cw.p('/* Common nested types */')
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy(cw, struct)
        cw.nl()

    # Per-op request policies only exist for the per-op / split models.
    if parsed.kernel_policy in {'per-op', 'split'}:
        for _, op in parsed.ops.items():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, mode, op, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _main_user_header(parsed, cw, mode):
    """Emit the user-space header: enum helpers, types and prototypes."""
    cw.p('/* Enums */')
    put_op_name_fwd(parsed, cw)

    for _, const in parsed.consts.items():
        if isinstance(const, EnumSet):
            put_enum_to_str_fwd(parsed, cw, const)
    cw.nl()

    cw.p('/* Common nested types */')
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
        print_type_full(ri, struct)

    for _, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, 'dump')
            print_req_type(ri)
            print_req_type_helpers(ri)
            if not ri.type_consistent:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_wrapped_type(ri)

    for _, op in parsed.ntfs.items():
        if 'event' in op:
            ri = RenderInfo(cw, parsed, mode, op, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)
    cw.nl()


def _main_user_source(parsed, cw, mode):
    """Emit the user-space source: policies, parsers, requests, family."""
    cw.p('/* Enums */')
    put_op_name(parsed, cw)

    for _, const in parsed.consts.items():
        if isinstance(const, EnumSet):
            put_enum_to_str(parsed, cw, const)
    cw.nl()

    # Recursive nests need their policy / helper prototypes declared up
    # front so the definitions below can reference each other.
    has_recursive_nests = False
    cw.p('/* Policies */')
    for struct in parsed.pure_nested_structs.values():
        if struct.recursive:
            put_typol_fwd(cw, struct)
            has_recursive_nests = True
    if has_recursive_nests:
        cw.nl()
    for name in parsed.pure_nested_structs:
        struct = Struct(parsed, name)
        put_typol(cw, struct)
    for name in parsed.root_sets:
        struct = Struct(parsed, name)
        put_typol(cw, struct)

    cw.p('/* Common nested types */')
    if has_recursive_nests:
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
            free_rsp_nested_prototype(ri)
            if struct.request:
                put_req_nested_prototype(ri, struct)
            if struct.reply:
                parse_rsp_nested_prototype(ri, struct)
        cw.nl()
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for _, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_free(ri)
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, "dump")
            if not ri.type_consistent:
                parse_rsp_msg(ri, deref=True)
            print_req_free(ri)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_ntf_type_free(ri)

    for _, op in parsed.ntfs.items():
        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")

            # The event's reply parser is generated through a "do" RenderInfo
            # (NOTE(review): presumably relies on the spec loader giving events
            # a do/reply section - confirm).
            ri = RenderInfo(cw, parsed, mode, op, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, mode, op, "event")
            print_ntf_type_free(ri)
    cw.nl()
    render_user_family(parsed, cw, False)


def main():
    """
    Command line entry point: generate C code from a YNL netlink spec.

    Loads the spec given by --spec and checks its license and message enum
    model against the requested --mode.  It then writes one of the following
    through a CodeWriter targeting -o:
      * the uapi header;
      * the kernel header or source;
      * the user-space library header or source.
    --header / --source selects between header and source output.

    Exits with status 1, after printing a diagnostic to stderr, when:
      * the spec is not valid YAML;
      * the spec license is wrong;
      * the message enum model is unsupported for the mode.
    """
    import sys

    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=['user', 'kernel', 'uapi'])
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    # --header / --source share one dest; None means neither was given.
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)', file=sys.stderr)
        sys.exit(1)

    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation',
              file=sys.stderr)
        sys.exit(1)

    # Resolve the spec's in-tree path before creating the writer, so a failure
    # here cannot leave a half-initialised output behind.
    _, spec_kernel = find_kernel_root(args.spec)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Record the extra arguments so the output can be regenerated as-is.
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    _main_includes(args, parsed, cw)

    if args.mode == "kernel":
        if args.header:
            _main_kernel_header(parsed, cw, args.mode)
        else:
            _main_kernel_source(parsed, cw, args.mode)
    elif args.mode == "user":
        if args.header:
            _main_user_header(parsed, cw, args.mode)
        else:
            _main_user_source(parsed, cw, args.mode)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2915
2916
if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()
This page took 0.195673 seconds and 4 git commands to generate.