diff options
Diffstat (limited to 'tools/net/ynl/lib')
-rw-r--r-- | tools/net/ynl/lib/.gitignore | 2 | ||||
-rw-r--r-- | tools/net/ynl/lib/Makefile | 10 | ||||
-rw-r--r-- | tools/net/ynl/lib/__init__.py | 8 | ||||
-rw-r--r-- | tools/net/ynl/lib/nlspec.py | 606 | ||||
-rw-r--r-- | tools/net/ynl/lib/ynl-priv.h | 369 | ||||
-rw-r--r-- | tools/net/ynl/lib/ynl.c | 409 | ||||
-rw-r--r-- | tools/net/ynl/lib/ynl.h | 19 | ||||
-rw-r--r-- | tools/net/ynl/lib/ynl.py | 824 |
8 files changed, 593 insertions, 1654 deletions
diff --git a/tools/net/ynl/lib/.gitignore b/tools/net/ynl/lib/.gitignore index c18dd8d83cee..a4383358ec72 100644 --- a/tools/net/ynl/lib/.gitignore +++ b/tools/net/ynl/lib/.gitignore @@ -1 +1 @@ -__pycache__/ +*.d diff --git a/tools/net/ynl/lib/Makefile b/tools/net/ynl/lib/Makefile index d2e50fd0a52d..4b2b98704ff9 100644 --- a/tools/net/ynl/lib/Makefile +++ b/tools/net/ynl/lib/Makefile @@ -1,7 +1,7 @@ # SPDX-License-Identifier: GPL-2.0 CC=gcc -CFLAGS=-std=gnu11 -O2 -W -Wall -Wextra -Wno-unused-parameter -Wshadow +CFLAGS += -std=gnu11 -O2 -W -Wall -Wextra -Wno-unused-parameter -Wshadow ifeq ("$(DEBUG)","1") CFLAGS += -g -fsanitize=address -fsanitize=leak -static-libasan endif @@ -14,15 +14,17 @@ include $(wildcard *.d) all: ynl.a ynl.a: $(OBJS) - ar rcs $@ $(OBJS) + @echo -e "\tAR $@" + @ar rcs $@ $(OBJS) + clean: rm -f *.o *.d *~ -hardclean: clean +distclean: clean rm -f *.a %.o: %.c $(COMPILE.c) -MMD -c -o $@ $< -.PHONY: all clean +.PHONY: all clean distclean .DEFAULT_GOAL=all diff --git a/tools/net/ynl/lib/__init__.py b/tools/net/ynl/lib/__init__.py deleted file mode 100644 index f7eaa07783e7..000000000000 --- a/tools/net/ynl/lib/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause - -from .nlspec import SpecAttr, SpecAttrSet, SpecEnumEntry, SpecEnumSet, \ - SpecFamily, SpecOperation -from .ynl import YnlFamily, Netlink - -__all__ = ["SpecAttr", "SpecAttrSet", "SpecEnumEntry", "SpecEnumSet", - "SpecFamily", "SpecOperation", "YnlFamily", "Netlink"] diff --git a/tools/net/ynl/lib/nlspec.py b/tools/net/ynl/lib/nlspec.py deleted file mode 100644 index 44f13e383e8a..000000000000 --- a/tools/net/ynl/lib/nlspec.py +++ /dev/null @@ -1,606 +0,0 @@ -# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause - -import collections -import importlib -import os -import yaml - - -# To be loaded dynamically as needed -jsonschema = None - - -class SpecElement: - """Netlink spec element. - - Abstract element of the Netlink spec. Implements the dictionary interface - for access to the raw spec. Supports iterative resolution of dependencies - across elements and class inheritance levels. The elements of the spec - may refer to each other, and although loops should be very rare, having - to maintain correct ordering of instantiation is painful, so the resolve() - method should be used to perform parts of init which require access to - other parts of the spec. - - Attributes: - yaml raw spec as loaded from the spec file - family back reference to the full family - - name name of the entity as listed in the spec (optional) - ident_name name which can be safely used as identifier in code (optional) - """ - def __init__(self, family, yaml): - self.yaml = yaml - self.family = family - - if 'name' in self.yaml: - self.name = self.yaml['name'] - self.ident_name = self.name.replace('-', '_') - - self._super_resolved = False - family.add_unresolved(self) - - def __getitem__(self, key): - return self.yaml[key] - - def __contains__(self, key): - return key in self.yaml - - def get(self, key, default=None): - return self.yaml.get(key, default) - - def resolve_up(self, up): - if not self._super_resolved: - up.resolve() - self._super_resolved = True - - def resolve(self): - pass - - -class SpecEnumEntry(SpecElement): - """ Entry within an enum declared in the Netlink spec. - - Attributes: - doc documentation string - enum_set back reference to the enum - value numerical value of this enum (use accessors in most situations!) - - Methods: - raw_value raw value, i.e. the id in the enum, unlike user value which is a mask for flags - user_value user value, same as raw value for enums, for flags it's the mask - """ - def __init__(self, enum_set, yaml, prev, value_start): - if isinstance(yaml, str): - yaml = {'name': yaml} - super().__init__(enum_set.family, yaml) - - self.doc = yaml.get('doc', '') - self.enum_set = enum_set - - if 'value' in yaml: - self.value = yaml['value'] - elif prev: - self.value = prev.value + 1 - else: - self.value = value_start - - def has_doc(self): - return bool(self.doc) - - def raw_value(self): - return self.value - - def user_value(self, as_flags=None): - if self.enum_set['type'] == 'flags' or as_flags: - return 1 << self.value - else: - return self.value - - -class SpecEnumSet(SpecElement): - """ Enum type - - Represents an enumeration (list of numerical constants) - as declared in the "definitions" section of the spec. - - Attributes: - type enum or flags - entries entries by name - entries_by_val entries by value - Methods: - get_mask for flags compute the mask of all defined values - """ - def __init__(self, family, yaml): - super().__init__(family, yaml) - - self.type = yaml['type'] - - prev_entry = None - value_start = self.yaml.get('value-start', 0) - self.entries = dict() - self.entries_by_val = dict() - for entry in self.yaml['entries']: - e = self.new_entry(entry, prev_entry, value_start) - self.entries[e.name] = e - self.entries_by_val[e.raw_value()] = e - prev_entry = e - - def new_entry(self, entry, prev_entry, value_start): - return SpecEnumEntry(self, entry, prev_entry, value_start) - - def has_doc(self): - if 'doc' in self.yaml: - return True - for entry in self.entries.values(): - if entry.has_doc(): - return True - return False - - def get_mask(self, as_flags=None): - mask = 0 - for e in self.entries.values(): - mask += e.user_value(as_flags) - return mask - - -class SpecAttr(SpecElement): - """ Single Netlink atttribute type - - Represents a single attribute type within an attr space. - - Attributes: - type string, attribute type - value numerical ID when serialized - attr_set Attribute Set containing this attr - is_multi bool, attr may repeat multiple times - struct_name string, name of struct definition - sub_type string, name of sub type - len integer, optional byte length of binary types - display_hint string, hint to help choose format specifier - when displaying the value - sub_message string, name of sub message type - selector string, name of attribute used to select - sub-message type - - is_auto_scalar bool, attr is a variable-size scalar - """ - def __init__(self, family, attr_set, yaml, value): - super().__init__(family, yaml) - - self.type = yaml['type'] - self.value = value - self.attr_set = attr_set - self.is_multi = yaml.get('multi-attr', False) - self.struct_name = yaml.get('struct') - self.sub_type = yaml.get('sub-type') - self.byte_order = yaml.get('byte-order') - self.len = yaml.get('len') - self.display_hint = yaml.get('display-hint') - self.sub_message = yaml.get('sub-message') - self.selector = yaml.get('selector') - - self.is_auto_scalar = self.type == "sint" or self.type == "uint" - - -class SpecAttrSet(SpecElement): - """ Netlink Attribute Set class. - - Represents a ID space of attributes within Netlink. - - Note that unlike other elements, which expose contents of the raw spec - via the dictionary interface Attribute Set exposes attributes by name. - - Attributes: - attrs ordered dict of all attributes (indexed by name) - attrs_by_val ordered dict of all attributes (indexed by value) - subset_of parent set if this is a subset, otherwise None - """ - def __init__(self, family, yaml): - super().__init__(family, yaml) - - self.subset_of = self.yaml.get('subset-of', None) - - self.attrs = collections.OrderedDict() - self.attrs_by_val = collections.OrderedDict() - - if self.subset_of is None: - val = 1 - for elem in self.yaml['attributes']: - if 'value' in elem: - val = elem['value'] - - attr = self.new_attr(elem, val) - self.attrs[attr.name] = attr - self.attrs_by_val[attr.value] = attr - val += 1 - else: - real_set = family.attr_sets[self.subset_of] - for elem in self.yaml['attributes']: - attr = real_set[elem['name']] - self.attrs[attr.name] = attr - self.attrs_by_val[attr.value] = attr - - def new_attr(self, elem, value): - return SpecAttr(self.family, self, elem, value) - - def __getitem__(self, key): - return self.attrs[key] - - def __contains__(self, key): - return key in self.attrs - - def __iter__(self): - yield from self.attrs - - def items(self): - return self.attrs.items() - - -class SpecStructMember(SpecElement): - """Struct member attribute - - Represents a single struct member attribute. - - Attributes: - type string, type of the member attribute - byte_order string or None for native byte order - enum string, name of the enum definition - len integer, optional byte length of binary types - display_hint string, hint to help choose format specifier - when displaying the value - """ - def __init__(self, family, yaml): - super().__init__(family, yaml) - self.type = yaml['type'] - self.byte_order = yaml.get('byte-order') - self.enum = yaml.get('enum') - self.len = yaml.get('len') - self.display_hint = yaml.get('display-hint') - - -class SpecStruct(SpecElement): - """Netlink struct type - - Represents a C struct definition. - - Attributes: - members ordered list of struct members - """ - def __init__(self, family, yaml): - super().__init__(family, yaml) - - self.members = [] - for member in yaml.get('members', []): - self.members.append(self.new_member(family, member)) - - def new_member(self, family, elem): - return SpecStructMember(family, elem) - - def __iter__(self): - yield from self.members - - def items(self): - return self.members.items() - - -class SpecSubMessage(SpecElement): - """ Netlink sub-message definition - - Represents a set of sub-message formats for polymorphic nlattrs - that contain type-specific sub messages. - - Attributes: - name string, name of sub-message definition - formats dict of sub-message formats indexed by match value - """ - def __init__(self, family, yaml): - super().__init__(family, yaml) - - self.formats = collections.OrderedDict() - for elem in self.yaml['formats']: - format = self.new_format(family, elem) - self.formats[format.value] = format - - def new_format(self, family, format): - return SpecSubMessageFormat(family, format) - - -class SpecSubMessageFormat(SpecElement): - """ Netlink sub-message definition - - Represents a set of sub-message formats for polymorphic nlattrs - that contain type-specific sub messages. - - Attributes: - value attribute value to match against type selector - fixed_header string, name of fixed header, or None - attr_set string, name of attribute set, or None - """ - def __init__(self, family, yaml): - super().__init__(family, yaml) - - self.value = yaml.get('value') - self.fixed_header = yaml.get('fixed-header') - self.attr_set = yaml.get('attribute-set') - - -class SpecOperation(SpecElement): - """Netlink Operation - - Information about a single Netlink operation. - - Attributes: - value numerical ID when serialized, None if req/rsp values differ - - req_value numerical ID when serialized, user -> kernel - rsp_value numerical ID when serialized, user <- kernel - is_call bool, whether the operation is a call - is_async bool, whether the operation is a notification - is_resv bool, whether the operation does not exist (it's just a reserved ID) - attr_set attribute set name - fixed_header string, optional name of fixed header struct - - yaml raw spec as loaded from the spec file - """ - def __init__(self, family, yaml, req_value, rsp_value): - super().__init__(family, yaml) - - self.value = req_value if req_value == rsp_value else None - self.req_value = req_value - self.rsp_value = rsp_value - - self.is_call = 'do' in yaml or 'dump' in yaml - self.is_async = 'notify' in yaml or 'event' in yaml - self.is_resv = not self.is_async and not self.is_call - self.fixed_header = self.yaml.get('fixed-header', family.fixed_header) - - # Added by resolve: - self.attr_set = None - delattr(self, "attr_set") - - def resolve(self): - self.resolve_up(super()) - - if 'attribute-set' in self.yaml: - attr_set_name = self.yaml['attribute-set'] - elif 'notify' in self.yaml: - msg = self.family.msgs[self.yaml['notify']] - attr_set_name = msg['attribute-set'] - elif self.is_resv: - attr_set_name = '' - else: - raise Exception(f"Can't resolve attribute set for op '{self.name}'") - if attr_set_name: - self.attr_set = self.family.attr_sets[attr_set_name] - - -class SpecMcastGroup(SpecElement): - """Netlink Multicast Group - - Information about a multicast group. - - Value is only used for classic netlink families that use the - netlink-raw schema. Genetlink families use dynamic ID allocation - where the ids of multicast groups get resolved at runtime. Value - will be None for genetlink families. - - Attributes: - name name of the mulitcast group - value integer id of this multicast group for netlink-raw or None - yaml raw spec as loaded from the spec file - """ - def __init__(self, family, yaml): - super().__init__(family, yaml) - self.value = self.yaml.get('value') - - -class SpecFamily(SpecElement): - """ Netlink Family Spec class. - - Netlink family information loaded from a spec (e.g. in YAML). - Takes care of unfolding implicit information which can be skipped - in the spec itself for brevity. - - The class can be used like a dictionary to access the raw spec - elements but that's usually a bad idea. - - Attributes: - proto protocol type (e.g. genetlink) - msg_id_model enum-model for operations (unified, directional etc.) - license spec license (loaded from an SPDX tag on the spec) - - attr_sets dict of attribute sets - msgs dict of all messages (index by name) - sub_msgs dict of all sub messages (index by name) - ops dict of all valid requests / responses - ntfs dict of all async events - consts dict of all constants/enums - fixed_header string, optional name of family default fixed header struct - mcast_groups dict of all multicast groups (index by name) - """ - def __init__(self, spec_path, schema_path=None, exclude_ops=None): - with open(spec_path, "r") as stream: - prefix = '# SPDX-License-Identifier: ' - first = stream.readline().strip() - if not first.startswith(prefix): - raise Exception('SPDX license tag required in the spec') - self.license = first[len(prefix):] - - stream.seek(0) - spec = yaml.safe_load(stream) - - self._resolution_list = [] - - super().__init__(self, spec) - - self._exclude_ops = exclude_ops if exclude_ops else [] - - self.proto = self.yaml.get('protocol', 'genetlink') - self.msg_id_model = self.yaml['operations'].get('enum-model', 'unified') - - if schema_path is None: - schema_path = os.path.dirname(os.path.dirname(spec_path)) + f'/{self.proto}.yaml' - if schema_path: - global jsonschema - - with open(schema_path, "r") as stream: - schema = yaml.safe_load(stream) - - if jsonschema is None: - jsonschema = importlib.import_module("jsonschema") - - jsonschema.validate(self.yaml, schema) - - self.attr_sets = collections.OrderedDict() - self.sub_msgs = collections.OrderedDict() - self.msgs = collections.OrderedDict() - self.req_by_value = collections.OrderedDict() - self.rsp_by_value = collections.OrderedDict() - self.ops = collections.OrderedDict() - self.ntfs = collections.OrderedDict() - self.consts = collections.OrderedDict() - self.mcast_groups = collections.OrderedDict() - - last_exception = None - while len(self._resolution_list) > 0: - resolved = [] - unresolved = self._resolution_list - self._resolution_list = [] - - for elem in unresolved: - try: - elem.resolve() - except (KeyError, AttributeError) as e: - self._resolution_list.append(elem) - last_exception = e - continue - - resolved.append(elem) - - if len(resolved) == 0: - raise last_exception - - def new_enum(self, elem): - return SpecEnumSet(self, elem) - - def new_attr_set(self, elem): - return SpecAttrSet(self, elem) - - def new_struct(self, elem): - return SpecStruct(self, elem) - - def new_sub_message(self, elem): - return SpecSubMessage(self, elem); - - def new_operation(self, elem, req_val, rsp_val): - return SpecOperation(self, elem, req_val, rsp_val) - - def new_mcast_group(self, elem): - return SpecMcastGroup(self, elem) - - def add_unresolved(self, elem): - self._resolution_list.append(elem) - - def _dictify_ops_unified(self): - self.fixed_header = self.yaml['operations'].get('fixed-header') - val = 1 - for elem in self.yaml['operations']['list']: - if 'value' in elem: - val = elem['value'] - - op = self.new_operation(elem, val, val) - val += 1 - - self.msgs[op.name] = op - - def _dictify_ops_directional(self): - self.fixed_header = self.yaml['operations'].get('fixed-header') - req_val = rsp_val = 1 - for elem in self.yaml['operations']['list']: - if 'notify' in elem or 'event' in elem: - if 'value' in elem: - rsp_val = elem['value'] - req_val_next = req_val - rsp_val_next = rsp_val + 1 - req_val = None - elif 'do' in elem or 'dump' in elem: - mode = elem['do'] if 'do' in elem else elem['dump'] - - v = mode.get('request', {}).get('value', None) - if v: - req_val = v - v = mode.get('reply', {}).get('value', None) - if v: - rsp_val = v - - rsp_inc = 1 if 'reply' in mode else 0 - req_val_next = req_val + 1 - rsp_val_next = rsp_val + rsp_inc - else: - raise Exception("Can't parse directional ops") - - if req_val == req_val_next: - req_val = None - if rsp_val == rsp_val_next: - rsp_val = None - - skip = False - for exclude in self._exclude_ops: - skip |= bool(exclude.match(elem['name'])) - if not skip: - op = self.new_operation(elem, req_val, rsp_val) - - req_val = req_val_next - rsp_val = rsp_val_next - - self.msgs[op.name] = op - - def find_operation(self, name): - """ - For a given operation name, find and return operation spec. - """ - for op in self.yaml['operations']['list']: - if name == op['name']: - return op - return None - - def resolve(self): - self.resolve_up(super()) - - definitions = self.yaml.get('definitions', []) - for elem in definitions: - if elem['type'] == 'enum' or elem['type'] == 'flags': - self.consts[elem['name']] = self.new_enum(elem) - elif elem['type'] == 'struct': - self.consts[elem['name']] = self.new_struct(elem) - else: - self.consts[elem['name']] = elem - - for elem in self.yaml['attribute-sets']: - attr_set = self.new_attr_set(elem) - self.attr_sets[elem['name']] = attr_set - - for elem in self.yaml.get('sub-messages', []): - sub_message = self.new_sub_message(elem) - self.sub_msgs[sub_message.name] = sub_message - - if self.msg_id_model == 'unified': - self._dictify_ops_unified() - elif self.msg_id_model == 'directional': - self._dictify_ops_directional() - - for op in self.msgs.values(): - if op.req_value is not None: - self.req_by_value[op.req_value] = op - if op.rsp_value is not None: - self.rsp_by_value[op.rsp_value] = op - if not op.is_async and 'attribute-set' in op: - self.ops[op.name] = op - elif op.is_async: - self.ntfs[op.name] = op - - mcgs = self.yaml.get('mcast-groups') - if mcgs: - for elem in mcgs['list']: - mcg = self.new_mcast_group(elem) - self.mcast_groups[elem['name']] = mcg diff --git a/tools/net/ynl/lib/ynl-priv.h b/tools/net/ynl/lib/ynl-priv.h index 7491da8e7555..3c09a7bbfba5 100644 --- a/tools/net/ynl/lib/ynl-priv.h +++ b/tools/net/ynl/lib/ynl-priv.h @@ -2,16 +2,16 @@ #ifndef __YNL_C_PRIV_H #define __YNL_C_PRIV_H 1 +#include <stdbool.h> #include <stddef.h> -#include <libmnl/libmnl.h> #include <linux/types.h> +struct ynl_parse_arg; + /* * YNL internals / low level stuff */ -/* Generic mnl helper code */ - enum ynl_policy_type { YNL_PT_REJECT = 1, YNL_PT_IGNORE, @@ -27,21 +27,35 @@ enum ynl_policy_type { YNL_PT_BITFIELD32, }; +enum ynl_parse_result { + YNL_PARSE_CB_ERROR = -1, + YNL_PARSE_CB_STOP = 0, + YNL_PARSE_CB_OK = 1, +}; + +#define YNL_SOCKET_BUFFER_SIZE (1 << 17) + +#define YNL_ARRAY_SIZE(array) (sizeof(array) ? \ + sizeof(array) / sizeof(array[0]) : 0) + +typedef int (*ynl_parse_cb_t)(const struct nlmsghdr *nlh, + struct ynl_parse_arg *yarg); + struct ynl_policy_attr { enum ynl_policy_type type; unsigned int len; const char *name; - struct ynl_policy_nest *nest; + const struct ynl_policy_nest *nest; }; struct ynl_policy_nest { unsigned int max_attr; - struct ynl_policy_attr *table; + const struct ynl_policy_attr *table; }; struct ynl_parse_arg { struct ynl_sock *ys; - struct ynl_policy_nest *rsp_policy; + const struct ynl_policy_nest *rsp_policy; void *data; }; @@ -65,7 +79,7 @@ static inline void *ynl_dump_obj_next(void *obj) struct ynl_dump_list_type *list; uptr -= offsetof(struct ynl_dump_list_type, data); - list = (void *)uptr; + list = (struct ynl_dump_list_type *)uptr; uptr = (unsigned long)list->next; uptr += offsetof(struct ynl_dump_list_type, data); @@ -80,8 +94,6 @@ struct ynl_ntf_base_type { unsigned char data[] __attribute__((aligned(8))); }; -extern mnl_cb_t ynl_cb_array[NLMSG_MIN_TYPE]; - struct nlmsghdr * ynl_gemsg_start_req(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version); struct nlmsghdr * @@ -89,30 +101,26 @@ ynl_gemsg_start_dump(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version); int ynl_attr_validate(struct ynl_parse_arg *yarg, const struct nlattr *attr); -int ynl_recv_ack(struct ynl_sock *ys, int ret); -int ynl_cb_null(const struct nlmsghdr *nlh, void *data); - /* YNL specific helpers used by the auto-generated code */ struct ynl_req_state { struct ynl_parse_arg yarg; - mnl_cb_t cb; + ynl_parse_cb_t cb; __u32 rsp_cmd; }; struct ynl_dump_state { - struct ynl_sock *ys; - struct ynl_policy_nest *rsp_policy; + struct ynl_parse_arg yarg; void *first; struct ynl_dump_list_type *last; size_t alloc_sz; - mnl_cb_t cb; + ynl_parse_cb_t cb; __u32 rsp_cmd; }; struct ynl_ntf_info { - struct ynl_policy_nest *policy; - mnl_cb_t cb; + const struct ynl_policy_nest *policy; + ynl_parse_cb_t cb; size_t alloc_sz; void (*free)(struct ynl_ntf_base_type *ntf); }; @@ -125,20 +133,325 @@ int ynl_exec_dump(struct ynl_sock *ys, struct nlmsghdr *req_nlh, void ynl_error_unknown_notification(struct ynl_sock *ys, __u8 cmd); int ynl_error_parse(struct ynl_parse_arg *yarg, const char *msg); -#ifndef MNL_HAS_AUTO_SCALARS -static inline uint64_t mnl_attr_get_uint(const struct nlattr *attr) +/* Netlink message handling helpers */ + +#define YNL_MSG_OVERFLOW 1 + +static inline struct nlmsghdr *ynl_nlmsg_put_header(void *buf) +{ + struct nlmsghdr *nlh = (struct nlmsghdr *)buf; + + memset(nlh, 0, sizeof(*nlh)); + nlh->nlmsg_len = NLMSG_HDRLEN; + + return nlh; +} + +static inline unsigned int ynl_nlmsg_data_len(const struct nlmsghdr *nlh) +{ + return nlh->nlmsg_len - NLMSG_HDRLEN; +} + +static inline void *ynl_nlmsg_data(const struct nlmsghdr *nlh) +{ + return (unsigned char *)nlh + NLMSG_HDRLEN; +} + +static inline void * +ynl_nlmsg_data_offset(const struct nlmsghdr *nlh, unsigned int offset) +{ + return (unsigned char *)nlh + NLMSG_HDRLEN + offset; +} + +static inline void *ynl_nlmsg_end_addr(const struct nlmsghdr *nlh) +{ + return (char *)nlh + nlh->nlmsg_len; +} + +static inline void * +ynl_nlmsg_put_extra_header(struct nlmsghdr *nlh, unsigned int size) +{ + void *tail = ynl_nlmsg_end_addr(nlh); + + nlh->nlmsg_len += NLMSG_ALIGN(size); + return tail; +} + +/* Netlink attribute helpers */ + +static inline unsigned int ynl_attr_type(const struct nlattr *attr) +{ + return attr->nla_type & NLA_TYPE_MASK; +} + +static inline unsigned int ynl_attr_data_len(const struct nlattr *attr) +{ + return attr->nla_len - NLA_HDRLEN; +} + +static inline void *ynl_attr_data(const struct nlattr *attr) +{ + return (unsigned char *)attr + NLA_HDRLEN; +} + +static inline void *ynl_attr_data_end(const struct nlattr *attr) +{ + return (char *)ynl_attr_data(attr) + ynl_attr_data_len(attr); +} + +#define ynl_attr_for_each(attr, nlh, fixed_hdr_sz) \ + for ((attr) = ynl_attr_first(nlh, (nlh)->nlmsg_len, \ + NLMSG_HDRLEN + fixed_hdr_sz); attr; \ + (attr) = ynl_attr_next(ynl_nlmsg_end_addr(nlh), attr)) + +#define ynl_attr_for_each_nested(attr, outer) \ + for ((attr) = ynl_attr_first(outer, outer->nla_len, \ + sizeof(struct nlattr)); attr; \ + (attr) = ynl_attr_next(ynl_attr_data_end(outer), attr)) + +#define ynl_attr_for_each_payload(start, len, attr) \ + for ((attr) = ynl_attr_first(start, len, 0); attr; \ + (attr) = ynl_attr_next(start + len, attr)) + +static inline struct nlattr * +ynl_attr_if_good(const void *end, struct nlattr *attr) +{ + if (attr + 1 > (const struct nlattr *)end) + return NULL; + if (ynl_attr_data_end(attr) > end) + return NULL; + return attr; +} + +static inline struct nlattr * +ynl_attr_next(const void *end, const struct nlattr *prev) +{ + struct nlattr *attr; + + attr = (struct nlattr *)((char *)prev + NLA_ALIGN(prev->nla_len)); + return ynl_attr_if_good(end, attr); +} + +static inline struct nlattr * +ynl_attr_first(const void *start, size_t len, size_t skip) +{ + struct nlattr *attr; + + attr = (struct nlattr *)((char *)start + NLMSG_ALIGN(skip)); + return ynl_attr_if_good((char *)start + len, attr); +} + +static inline bool +__ynl_attr_put_overflow(struct nlmsghdr *nlh, size_t size) +{ + bool o; + + /* ynl_msg_start() stashed buffer length in nlmsg_pid. */ + o = nlh->nlmsg_len + NLA_HDRLEN + NLMSG_ALIGN(size) > nlh->nlmsg_pid; + if (o) + /* YNL_MSG_OVERFLOW is < NLMSG_HDRLEN, all subsequent checks + * are guaranteed to fail. + */ + nlh->nlmsg_pid = YNL_MSG_OVERFLOW; + return o; +} + +static inline struct nlattr * +ynl_attr_nest_start(struct nlmsghdr *nlh, unsigned int attr_type) { - if (mnl_attr_get_payload_len(attr) == 4) - return mnl_attr_get_u32(attr); - return mnl_attr_get_u64(attr); + struct nlattr *attr; + + if (__ynl_attr_put_overflow(nlh, 0)) + return (struct nlattr *)ynl_nlmsg_end_addr(nlh) - 1; + + attr = (struct nlattr *)ynl_nlmsg_end_addr(nlh); + attr->nla_type = attr_type | NLA_F_NESTED; + nlh->nlmsg_len += NLA_HDRLEN; + + return attr; } static inline void -mnl_attr_put_uint(struct nlmsghdr *nlh, uint16_t type, uint64_t data) +ynl_attr_nest_end(struct nlmsghdr *nlh, struct nlattr *attr) { - if ((uint32_t)data == (uint64_t)data) - return mnl_attr_put_u32(nlh, type, data); - return mnl_attr_put_u64(nlh, type, data); + attr->nla_len = (char *)ynl_nlmsg_end_addr(nlh) - (char *)attr; +} + +static inline void +ynl_attr_put(struct nlmsghdr *nlh, unsigned int attr_type, + const void *value, size_t size) +{ + struct nlattr *attr; + + if (__ynl_attr_put_overflow(nlh, size)) + return; + + attr = (struct nlattr *)ynl_nlmsg_end_addr(nlh); + attr->nla_type = attr_type; + attr->nla_len = NLA_HDRLEN + size; + + memcpy(ynl_attr_data(attr), value, size); + + nlh->nlmsg_len += NLMSG_ALIGN(attr->nla_len); +} + +static inline void +ynl_attr_put_str(struct nlmsghdr *nlh, unsigned int attr_type, const char *str) +{ + struct nlattr *attr; + size_t len; + + len = strlen(str); + if (__ynl_attr_put_overflow(nlh, len)) + return; + + attr = (struct nlattr *)ynl_nlmsg_end_addr(nlh); + attr->nla_type = attr_type; + + strcpy((char *)ynl_attr_data(attr), str); + attr->nla_len = NLA_HDRLEN + NLA_ALIGN(len); + + nlh->nlmsg_len += NLMSG_ALIGN(attr->nla_len); +} + +static inline const char *ynl_attr_get_str(const struct nlattr *attr) +{ + return (const char *)ynl_attr_data(attr); +} + +static inline __s8 ynl_attr_get_s8(const struct nlattr *attr) +{ + return *(__s8 *)ynl_attr_data(attr); +} + +static inline __s16 ynl_attr_get_s16(const struct nlattr *attr) +{ + return *(__s16 *)ynl_attr_data(attr); +} + +static inline __s32 ynl_attr_get_s32(const struct nlattr *attr) +{ + return *(__s32 *)ynl_attr_data(attr); +} + +static inline __s64 ynl_attr_get_s64(const struct nlattr *attr) +{ + __s64 tmp; + + memcpy(&tmp, (unsigned char *)(attr + 1), sizeof(tmp)); + return tmp; +} + +static inline __u8 ynl_attr_get_u8(const struct nlattr *attr) +{ + return *(__u8 *)ynl_attr_data(attr); +} + +static inline __u16 ynl_attr_get_u16(const struct nlattr *attr) +{ + return *(__u16 *)ynl_attr_data(attr); +} + +static inline __u32 ynl_attr_get_u32(const struct nlattr *attr) +{ + return *(__u32 *)ynl_attr_data(attr); +} + +static inline __u64 ynl_attr_get_u64(const struct nlattr *attr) +{ + __u64 tmp; + + memcpy(&tmp, (unsigned char *)(attr + 1), sizeof(tmp)); + return tmp; +} + +static inline void +ynl_attr_put_s8(struct nlmsghdr *nlh, unsigned int attr_type, __s8 value) +{ + ynl_attr_put(nlh, attr_type, &value, sizeof(value)); +} + +static inline void +ynl_attr_put_s16(struct nlmsghdr *nlh, unsigned int attr_type, __s16 value) +{ + ynl_attr_put(nlh, attr_type, &value, sizeof(value)); +} + +static inline void +ynl_attr_put_s32(struct nlmsghdr *nlh, unsigned int attr_type, __s32 value) +{ + ynl_attr_put(nlh, attr_type, &value, sizeof(value)); +} + +static inline void +ynl_attr_put_s64(struct nlmsghdr *nlh, unsigned int attr_type, __s64 value) +{ + ynl_attr_put(nlh, attr_type, &value, sizeof(value)); +} + +static inline void +ynl_attr_put_u8(struct nlmsghdr *nlh, unsigned int attr_type, __u8 value) +{ + ynl_attr_put(nlh, attr_type, &value, sizeof(value)); +} + +static inline void +ynl_attr_put_u16(struct nlmsghdr *nlh, unsigned int attr_type, __u16 value) +{ + ynl_attr_put(nlh, attr_type, &value, sizeof(value)); +} + +static inline void +ynl_attr_put_u32(struct nlmsghdr *nlh, unsigned int attr_type, __u32 value) +{ + ynl_attr_put(nlh, attr_type, &value, sizeof(value)); +} + +static inline void +ynl_attr_put_u64(struct nlmsghdr *nlh, unsigned int attr_type, __u64 value) +{ + ynl_attr_put(nlh, attr_type, &value, sizeof(value)); +} + +static inline __u64 ynl_attr_get_uint(const struct nlattr *attr) +{ + switch (ynl_attr_data_len(attr)) { + case 4: + return ynl_attr_get_u32(attr); + case 8: + return ynl_attr_get_u64(attr); + default: + return 0; + } +} + +static inline __s64 ynl_attr_get_sint(const struct nlattr *attr) +{ + switch (ynl_attr_data_len(attr)) { + case 4: + return ynl_attr_get_s32(attr); + case 8: + return ynl_attr_get_s64(attr); + default: + return 0; + } +} + +static inline void +ynl_attr_put_uint(struct nlmsghdr *nlh, __u16 type, __u64 data) +{ + if ((__u32)data == (__u64)data) + ynl_attr_put_u32(nlh, type, data); + else + ynl_attr_put_u64(nlh, type, data); +} + +static inline void +ynl_attr_put_sint(struct nlmsghdr *nlh, __u16 type, __s64 data) +{ + if ((__s32)data == (__s64)data) + ynl_attr_put_s32(nlh, type, data); + else + ynl_attr_put_s64(nlh, type, data); } -#endif #endif diff --git a/tools/net/ynl/lib/ynl.c b/tools/net/ynl/lib/ynl.c index 45e49671ae87..ce32cb35007d 100644 --- a/tools/net/ynl/lib/ynl.c +++ b/tools/net/ynl/lib/ynl.c @@ -3,10 +3,11 @@ #include <poll.h> #include <string.h> #include <stdlib.h> +#include <stdio.h> +#include <unistd.h> #include <linux/types.h> - -#include <libmnl/libmnl.h> #include <linux/genetlink.h> +#include <sys/socket.h> #include "ynl.h" @@ -45,7 +46,7 @@ /* -- Netlink boiler plate */ static int -ynl_err_walk_report_one(struct ynl_policy_nest *policy, unsigned int type, +ynl_err_walk_report_one(const struct ynl_policy_nest *policy, unsigned int type, char *str, int str_sz, int *n) { if (!policy) { @@ -74,8 +75,8 @@ ynl_err_walk_report_one(struct ynl_policy_nest *policy, unsigned int type, static int ynl_err_walk(struct ynl_sock *ys, void *start, void *end, unsigned int off, - struct ynl_policy_nest *policy, char *str, int str_sz, - struct ynl_policy_nest **nest_pol) + const struct ynl_policy_nest *policy, char *str, int str_sz, + const struct ynl_policy_nest **nest_pol) { unsigned int astart_off, aend_off; const struct nlattr *attr; @@ -92,9 +93,9 @@ ynl_err_walk(struct ynl_sock *ys, void *start, void *end, unsigned int off, data_len = end - start; - mnl_attr_for_each_payload(start, data_len) { + ynl_attr_for_each_payload(start, data_len, attr) { astart_off = (char *)attr - (char *)start; - aend_off = astart_off + mnl_attr_get_payload_len(attr); + aend_off = (char *)ynl_attr_data_end(attr) - (char *)start; if (aend_off <= off) continue; @@ -106,7 +107,7 @@ ynl_err_walk(struct ynl_sock *ys, void *start, void *end, unsigned int off, off -= astart_off; - type = mnl_attr_get_type(attr); + type = ynl_attr_type(attr); if (ynl_err_walk_report_one(policy, type, str, str_sz, &n)) return n; @@ -124,8 +125,8 @@ ynl_err_walk(struct ynl_sock *ys, void *start, void *end, unsigned int off, } off -= sizeof(struct nlattr); - start = mnl_attr_get_payload(attr); - end = start + mnl_attr_get_payload_len(attr); + start = ynl_attr_data(attr); + end = start + ynl_attr_data_len(attr); return n + ynl_err_walk(ys, start, end, off, policy->table[type].nest, &str[n], str_sz - n, nest_pol); @@ -147,14 +148,14 @@ ynl_ext_ack_check(struct ynl_sock *ys, const struct nlmsghdr *nlh, if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS)) { yerr_msg(ys, "%s", strerror(ys->err.code)); - return MNL_CB_OK; + return YNL_PARSE_CB_OK; } - mnl_attr_for_each(attr, nlh, hlen) { + ynl_attr_for_each(attr, nlh, hlen) { unsigned int len, type; - len = mnl_attr_get_payload_len(attr); - type = mnl_attr_get_type(attr); + len = ynl_attr_data_len(attr); + type = ynl_attr_type(attr); if (type > NLMSGERR_ATTR_MAX) continue; @@ -166,12 +167,12 @@ ynl_ext_ack_check(struct ynl_sock *ys, const struct nlmsghdr *nlh, case NLMSGERR_ATTR_MISS_TYPE: case NLMSGERR_ATTR_MISS_NEST: if (len != sizeof(__u32)) - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; break; case NLMSGERR_ATTR_MSG: - str = mnl_attr_get_payload(attr); + str = ynl_attr_get_str(attr); if (str[len - 1]) - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; break; default: break; @@ -185,14 +186,13 @@ ynl_ext_ack_check(struct ynl_sock *ys, const struct nlmsghdr *nlh, unsigned int n, off; void *start, *end; - ys->err.attr_offs = mnl_attr_get_u32(tb[NLMSGERR_ATTR_OFFS]); + ys->err.attr_offs = ynl_attr_get_u32(tb[NLMSGERR_ATTR_OFFS]); n = snprintf(bad_attr, sizeof(bad_attr), "%sbad attribute: ", str ? " (" : ""); - start = mnl_nlmsg_get_payload_offset(ys->nlh, - ys->family->hdr_len); - end = mnl_nlmsg_get_payload_tail(ys->nlh); + start = ynl_nlmsg_data_offset(ys->nlh, ys->family->hdr_len); + end = ynl_nlmsg_end_addr(ys->nlh); off = ys->err.attr_offs; off -= sizeof(struct nlmsghdr); @@ -206,23 +206,22 @@ ynl_ext_ack_check(struct ynl_sock *ys, const struct nlmsghdr *nlh, bad_attr[n] = '\0'; } if (tb[NLMSGERR_ATTR_MISS_TYPE]) { - struct ynl_policy_nest *nest_pol = NULL; + const struct ynl_policy_nest *nest_pol = NULL; unsigned int n, off, type; void *start, *end; int n2; - type = mnl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_TYPE]); + type = ynl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_TYPE]); n = snprintf(miss_attr, sizeof(miss_attr), "%smissing attribute: ", bad_attr[0] ? ", " : (str ? " (" : "")); - start = mnl_nlmsg_get_payload_offset(ys->nlh, - ys->family->hdr_len); - end = mnl_nlmsg_get_payload_tail(ys->nlh); + start = ynl_nlmsg_data_offset(ys->nlh, ys->family->hdr_len); + end = ynl_nlmsg_end_addr(ys->nlh); nest_pol = ys->req_policy; if (tb[NLMSGERR_ATTR_MISS_NEST]) { - off = mnl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_NEST]); + off = ynl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_NEST]); off -= sizeof(struct nlmsghdr); off -= ys->family->hdr_len; @@ -254,13 +253,13 @@ ynl_ext_ack_check(struct ynl_sock *ys, const struct nlmsghdr *nlh, else yerr_msg(ys, "%s", strerror(ys->err.code)); - return MNL_CB_OK; + return YNL_PARSE_CB_OK; } -static int ynl_cb_error(const struct nlmsghdr *nlh, void *data) +static int +ynl_cb_error(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg) { - const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh); - struct ynl_parse_arg *yarg = data; + const struct nlmsgerr *err = ynl_nlmsg_data(nlh); unsigned int hlen; int code; @@ -270,16 +269,15 @@ static int ynl_cb_error(const struct nlmsghdr *nlh, void *data) hlen = sizeof(*err); if (!(nlh->nlmsg_flags & NLM_F_CAPPED)) - hlen += mnl_nlmsg_get_payload_len(&err->msg); + hlen += ynl_nlmsg_data_len(&err->msg); ynl_ext_ack_check(yarg->ys, nlh, hlen); - return code ? MNL_CB_ERROR : MNL_CB_STOP; + return code ? YNL_PARSE_CB_ERROR : YNL_PARSE_CB_STOP; } -static int ynl_cb_done(const struct nlmsghdr *nlh, void *data) +static int ynl_cb_done(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg) { - struct ynl_parse_arg *yarg = data; int err; err = *(int *)NLMSG_DATA(nlh); @@ -289,34 +287,22 @@ static int ynl_cb_done(const struct nlmsghdr *nlh, void *data) ynl_ext_ack_check(yarg->ys, nlh, sizeof(int)); - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; } - return MNL_CB_STOP; -} - -static int ynl_cb_noop(const struct nlmsghdr *nlh, void *data) -{ - return MNL_CB_OK; + return YNL_PARSE_CB_STOP; } -mnl_cb_t ynl_cb_array[NLMSG_MIN_TYPE] = { - [NLMSG_NOOP] = ynl_cb_noop, - [NLMSG_ERROR] = ynl_cb_error, - [NLMSG_DONE] = ynl_cb_done, - [NLMSG_OVERRUN] = ynl_cb_noop, -}; - /* Attribute validation */ int ynl_attr_validate(struct ynl_parse_arg *yarg, const struct nlattr *attr) { - struct ynl_policy_attr *policy; + const struct ynl_policy_attr *policy; unsigned int type, len; unsigned char *data; - data = mnl_attr_get_payload(attr); - len = mnl_attr_get_payload_len(attr); - type = mnl_attr_get_type(attr); + data = ynl_attr_data(attr); + len = ynl_attr_data_len(attr); + type = ynl_attr_type(attr); if (type > yarg->rsp_policy->max_attr) { yerr(yarg->ys, YNL_ERROR_INTERNAL, "Internal error, validating unknown attribute"); @@ -413,14 +399,38 @@ struct nlmsghdr *ynl_msg_start(struct ynl_sock *ys, __u32 id, __u16 flags) ynl_err_reset(ys); - nlh = ys->nlh = mnl_nlmsg_put_header(ys->tx_buf); + nlh = ys->nlh = ynl_nlmsg_put_header(ys->tx_buf); nlh->nlmsg_type = id; nlh->nlmsg_flags = flags; nlh->nlmsg_seq = ++ys->seq; + /* This is a local YNL hack for length checking, we put the buffer + * length in nlmsg_pid, since messages sent to the kernel always use + * PID 0. Message needs to be terminated with ynl_msg_end(). + */ + nlh->nlmsg_pid = YNL_SOCKET_BUFFER_SIZE; + return nlh; } +static int ynl_msg_end(struct ynl_sock *ys, struct nlmsghdr *nlh) +{ + /* We stash buffer length in nlmsg_pid. */ + if (nlh->nlmsg_pid == 0) { + yerr(ys, YNL_ERROR_INPUT_INVALID, + "Unknown input buffer length"); + return -EINVAL; + } + if (nlh->nlmsg_pid == YNL_MSG_OVERFLOW) { + yerr(ys, YNL_ERROR_INPUT_TOO_BIG, + "Constructed message longer than internal buffer"); + return -EMSGSIZE; + } + + nlh->nlmsg_pid = 0; + return 0; +} + struct nlmsghdr * ynl_gemsg_start(struct ynl_sock *ys, __u32 id, __u16 flags, __u8 cmd, __u8 version) @@ -435,7 +445,7 @@ ynl_gemsg_start(struct ynl_sock *ys, __u32 id, __u16 flags, gehdr.cmd = cmd; gehdr.version = version; - data = mnl_nlmsg_put_extra_header(nlh, sizeof(gehdr)); + data = ynl_nlmsg_put_extra_header(nlh, sizeof(gehdr)); memcpy(data, &gehdr, sizeof(gehdr)); return nlh; @@ -464,33 +474,85 @@ ynl_gemsg_start_dump(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version) cmd, version); } -int ynl_recv_ack(struct ynl_sock *ys, int ret) +static int ynl_cb_null(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg) { - struct ynl_parse_arg yarg = { .ys = ys, }; + yerr(yarg->ys, YNL_ERROR_UNEXPECT_MSG, + "Received a message when none were expected"); - if (!ret) { - yerr(ys, YNL_ERROR_EXPECT_ACK, - "Expecting an ACK but nothing received"); - return -1; + return YNL_PARSE_CB_ERROR; +} + +static int +__ynl_sock_read_msgs(struct ynl_parse_arg *yarg, ynl_parse_cb_t cb, int flags) +{ + struct ynl_sock *ys = yarg->ys; + const struct nlmsghdr *nlh; + ssize_t len, rem; + int ret; + + len = recv(ys->socket, ys->rx_buf, YNL_SOCKET_BUFFER_SIZE, flags); + if (len < 0) { + if (flags & MSG_DONTWAIT && errno == EAGAIN) + return YNL_PARSE_CB_STOP; + return len; } - ret = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE); - if (ret < 0) { - perr(ys, "Socket receive failed"); - return ret; + ret = YNL_PARSE_CB_STOP; + for (rem = len; rem > 0; NLMSG_NEXT(nlh, rem)) { + nlh = (struct nlmsghdr *)&ys->rx_buf[len - rem]; + if (!NLMSG_OK(nlh, rem)) { + yerr(yarg->ys, YNL_ERROR_INV_RESP, + "Invalid message or trailing data in the response."); + return YNL_PARSE_CB_ERROR; + } + + if (nlh->nlmsg_flags & NLM_F_DUMP_INTR) { + /* TODO: handle this better */ + yerr(yarg->ys, YNL_ERROR_DUMP_INTER, + "Dump interrupted / inconsistent, please retry."); + return YNL_PARSE_CB_ERROR; + } + + switch (nlh->nlmsg_type) { + case 0: + yerr(yarg->ys, YNL_ERROR_INV_RESP, + "Invalid message type in the response."); + return YNL_PARSE_CB_ERROR; + case NLMSG_NOOP: + case NLMSG_OVERRUN ... NLMSG_MIN_TYPE - 1: + ret = YNL_PARSE_CB_OK; + break; + case NLMSG_ERROR: + ret = ynl_cb_error(nlh, yarg); + break; + case NLMSG_DONE: + ret = ynl_cb_done(nlh, yarg); + break; + default: + ret = cb(nlh, yarg); + break; + } } - return mnl_cb_run(ys->rx_buf, ret, ys->seq, ys->portid, - ynl_cb_null, &yarg); + + return ret; } -int ynl_cb_null(const struct nlmsghdr *nlh, void *data) +static int ynl_sock_read_msgs(struct ynl_parse_arg *yarg, ynl_parse_cb_t cb) { - struct ynl_parse_arg *yarg = data; + return __ynl_sock_read_msgs(yarg, cb, 0); +} - yerr(yarg->ys, YNL_ERROR_UNEXPECT_MSG, - "Received a message when none were expected"); +static int ynl_recv_ack(struct ynl_sock *ys, int ret) +{ + struct ynl_parse_arg yarg = { .ys = ys, }; + + if (!ret) { + yerr(ys, YNL_ERROR_EXPECT_ACK, + "Expecting an ACK but nothing received"); + return -1; + } - return MNL_CB_ERROR; + return ynl_sock_read_msgs(&yarg, ynl_cb_null); } /* Init/fini and genetlink boiler plate */ @@ -500,7 +562,7 @@ ynl_get_family_info_mcast(struct ynl_sock *ys, const struct nlattr *mcasts) const struct nlattr *entry, *attr; unsigned int i; - mnl_attr_for_each_nested(attr, mcasts) + ynl_attr_for_each_nested(attr, mcasts) ys->n_mcast_groups++; if (!ys->n_mcast_groups) @@ -509,16 +571,16 @@ ynl_get_family_info_mcast(struct ynl_sock *ys, const struct nlattr *mcasts) ys->mcast_groups = calloc(ys->n_mcast_groups, sizeof(*ys->mcast_groups)); if (!ys->mcast_groups) - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; i = 0; - mnl_attr_for_each_nested(entry, mcasts) { - mnl_attr_for_each_nested(attr, entry) { - if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GRP_ID) - ys->mcast_groups[i].id = mnl_attr_get_u32(attr); - if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GRP_NAME) { + ynl_attr_for_each_nested(entry, mcasts) { + ynl_attr_for_each_nested(attr, entry) { + if (ynl_attr_type(attr) == CTRL_ATTR_MCAST_GRP_ID) + ys->mcast_groups[i].id = ynl_attr_get_u32(attr); + if (ynl_attr_type(attr) == CTRL_ATTR_MCAST_GRP_NAME) { strncpy(ys->mcast_groups[i].name, - mnl_attr_get_str(attr), + ynl_attr_get_str(attr), GENL_NAMSIZ - 1); ys->mcast_groups[i].name[GENL_NAMSIZ - 1] = 0; } @@ -529,35 +591,35 @@ ynl_get_family_info_mcast(struct ynl_sock *ys, const struct nlattr *mcasts) return 0; } -static int ynl_get_family_info_cb(const struct nlmsghdr *nlh, void *data) +static int +ynl_get_family_info_cb(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg) { - struct ynl_parse_arg *yarg = data; struct ynl_sock *ys = yarg->ys; const struct nlattr *attr; bool found_id = true; - mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr)) { - if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GROUPS) + ynl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr)) { + if (ynl_attr_type(attr) == CTRL_ATTR_MCAST_GROUPS) if (ynl_get_family_info_mcast(ys, attr)) - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; - if (mnl_attr_get_type(attr) != CTRL_ATTR_FAMILY_ID) + if (ynl_attr_type(attr) != CTRL_ATTR_FAMILY_ID) continue; - if (mnl_attr_get_payload_len(attr) != sizeof(__u16)) { + if (ynl_attr_data_len(attr) != sizeof(__u16)) { yerr(ys, YNL_ERROR_ATTR_INVALID, "Invalid family ID"); - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; } - ys->family_id = mnl_attr_get_u16(attr); + ys->family_id = ynl_attr_get_u16(attr); found_id = true; } if (!found_id) { yerr(ys, YNL_ERROR_ATTR_MISSING, "Family ID missing"); - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; } - return MNL_CB_OK; + return YNL_PARSE_CB_OK; } static int ynl_sock_read_family(struct ynl_sock *ys, const char *family_name) @@ -567,22 +629,19 @@ static int ynl_sock_read_family(struct ynl_sock *ys, const char *family_name) int err; nlh = ynl_gemsg_start_req(ys, GENL_ID_CTRL, CTRL_CMD_GETFAMILY, 1); - mnl_attr_put_strz(nlh, CTRL_ATTR_FAMILY_NAME, family_name); + ynl_attr_put_str(nlh, CTRL_ATTR_FAMILY_NAME, family_name); - err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len); + err = ynl_msg_end(ys, nlh); + if (err < 0) + return err; + + err = send(ys->socket, nlh, nlh->nlmsg_len, 0); if (err < 0) { perr(ys, "failed to request socket family info"); return err; } - err = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE); - if (err <= 0) { - perr(ys, "failed to receive the socket family info"); - return err; - } - err = mnl_cb_run2(ys->rx_buf, err, ys->seq, ys->portid, - ynl_get_family_info_cb, &yarg, - ynl_cb_array, ARRAY_SIZE(ynl_cb_array)); + err = ynl_sock_read_msgs(&yarg, ynl_get_family_info_cb); if (err < 0) { free(ys->mcast_groups); perr(ys, "failed to receive the socket family info - no such family?"); @@ -601,38 +660,54 @@ static int ynl_sock_read_family(struct ynl_sock *ys, const char *family_name) struct ynl_sock * ynl_sock_create(const struct ynl_family *yf, struct ynl_error *yse) { + struct sockaddr_nl addr; struct ynl_sock *ys; + socklen_t addrlen; int one = 1; - ys = malloc(sizeof(*ys) + 2 * MNL_SOCKET_BUFFER_SIZE); + ys = malloc(sizeof(*ys) + 2 * YNL_SOCKET_BUFFER_SIZE); if (!ys) return NULL; memset(ys, 0, sizeof(*ys)); ys->family = yf; ys->tx_buf = &ys->raw_buf[0]; - ys->rx_buf = &ys->raw_buf[MNL_SOCKET_BUFFER_SIZE]; + ys->rx_buf = &ys->raw_buf[YNL_SOCKET_BUFFER_SIZE]; ys->ntf_last_next = &ys->ntf_first; - ys->sock = mnl_socket_open(NETLINK_GENERIC); - if (!ys->sock) { + ys->socket = socket(AF_NETLINK, SOCK_RAW, NETLINK_GENERIC); + if (ys->socket < 0) { __perr(yse, "failed to create a netlink socket"); goto err_free_sock; } - if (mnl_socket_setsockopt(ys->sock, NETLINK_CAP_ACK, - &one, sizeof(one))) { + if (setsockopt(ys->socket, SOL_NETLINK, NETLINK_CAP_ACK, + &one, sizeof(one))) { __perr(yse, "failed to enable netlink ACK"); goto err_close_sock; } - if (mnl_socket_setsockopt(ys->sock, NETLINK_EXT_ACK, - &one, sizeof(one))) { + if (setsockopt(ys->socket, SOL_NETLINK, NETLINK_EXT_ACK, + &one, sizeof(one))) { __perr(yse, "failed to enable netlink ext ACK"); goto err_close_sock; } + memset(&addr, 0, sizeof(addr)); + addr.nl_family = AF_NETLINK; + if (bind(ys->socket, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + __perr(yse, "unable to bind to a socket address"); + goto err_close_sock; + } + + memset(&addr, 0, sizeof(addr)); + addrlen = sizeof(addr); + if (getsockname(ys->socket, (struct sockaddr *)&addr, &addrlen) < 0) { + __perr(yse, "unable to read socket address"); + goto err_close_sock; + } + ys->portid = addr.nl_pid; ys->seq = random(); - ys->portid = mnl_socket_get_portid(ys->sock); + if (ynl_sock_read_family(ys, yf->name)) { if (yse) @@ -643,7 +718,7 @@ ynl_sock_create(const struct ynl_family *yf, struct ynl_error *yse) return ys; err_close_sock: - mnl_socket_close(ys->sock); + close(ys->socket); err_free_sock: free(ys); return NULL; @@ -653,7 +728,7 @@ void ynl_sock_destroy(struct ynl_sock *ys) { struct ynl_ntf_base_type *ntf; - mnl_socket_close(ys->sock); + close(ys->socket); while ((ntf = ynl_ntf_dequeue(ys))) ynl_ntf_free(ntf); free(ys->mcast_groups); @@ -680,9 +755,9 @@ int ynl_subscribe(struct ynl_sock *ys, const char *grp_name) return -1; } - err = mnl_socket_setsockopt(ys->sock, NETLINK_ADD_MEMBERSHIP, - &ys->mcast_groups[i].id, - sizeof(ys->mcast_groups[i].id)); + err = setsockopt(ys->socket, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, + &ys->mcast_groups[i].id, + sizeof(ys->mcast_groups[i].id)); if (err < 0) { perr(ys, "Subscribing to multicast group failed"); return -1; @@ -693,7 +768,7 @@ int ynl_subscribe(struct ynl_sock *ys, const char *grp_name) int ynl_socket_get_fd(struct ynl_sock *ys) { - return mnl_socket_get_fd(ys->sock); + return ys->socket; } struct ynl_ntf_base_type *ynl_ntf_dequeue(struct ynl_sock *ys) @@ -719,12 +794,12 @@ static int ynl_ntf_parse(struct ynl_sock *ys, const struct nlmsghdr *nlh) struct genlmsghdr *gehdr; int ret; - gehdr = mnl_nlmsg_get_payload(nlh); + gehdr = ynl_nlmsg_data(nlh); if (gehdr->cmd >= ys->family->ntf_info_size) - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; info = &ys->family->ntf_info[gehdr->cmd]; if (!info->cb) - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; rsp = calloc(1, info->alloc_sz); rsp->free = info->free; @@ -732,7 +807,7 @@ static int ynl_ntf_parse(struct ynl_sock *ys, const struct nlmsghdr *nlh) yarg.rsp_policy = info->policy; ret = info->cb(nlh, &yarg); - if (ret <= MNL_CB_STOP) + if (ret <= YNL_PARSE_CB_STOP) goto err_free; rsp->family = nlh->nlmsg_type; @@ -741,46 +816,27 @@ static int ynl_ntf_parse(struct ynl_sock *ys, const struct nlmsghdr *nlh) *ys->ntf_last_next = rsp; ys->ntf_last_next = &rsp->next; - return MNL_CB_OK; + return YNL_PARSE_CB_OK; err_free: info->free(rsp); - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; } -static int ynl_ntf_trampoline(const struct nlmsghdr *nlh, void *data) +static int +ynl_ntf_trampoline(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg) { - struct ynl_parse_arg *yarg = data; - return ynl_ntf_parse(yarg->ys, nlh); } int ynl_ntf_check(struct ynl_sock *ys) { struct ynl_parse_arg yarg = { .ys = ys, }; - ssize_t len; int err; do { - /* libmnl doesn't let us pass flags to the recv to make - * it non-blocking so we need to poll() or peek() :| - */ - struct pollfd pfd = { }; - - pfd.fd = mnl_socket_get_fd(ys->sock); - pfd.events = POLLIN; - err = poll(&pfd, 1, 1); - if (err < 1) - return err; - - len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, - MNL_SOCKET_BUFFER_SIZE); - if (len < 0) - return len; - - err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid, - ynl_ntf_trampoline, &yarg, - ynl_cb_array, NLMSG_MIN_TYPE); + err = __ynl_sock_read_msgs(&yarg, ynl_ntf_trampoline, + MSG_DONTWAIT); if (err < 0) return err; } while (err > 0); @@ -801,7 +857,7 @@ void ynl_error_unknown_notification(struct ynl_sock *ys, __u8 cmd) int ynl_error_parse(struct ynl_parse_arg *yarg, const char *msg) { yerr(yarg->ys, YNL_ERROR_INV_RESP, "Error parsing response: %s", msg); - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; } static int @@ -809,27 +865,28 @@ ynl_check_alien(struct ynl_sock *ys, const struct nlmsghdr *nlh, __u32 rsp_cmd) { struct genlmsghdr *gehdr; - if (mnl_nlmsg_get_payload_len(nlh) < sizeof(*gehdr)) { + if (ynl_nlmsg_data_len(nlh) < sizeof(*gehdr)) { yerr(ys, YNL_ERROR_INV_RESP, "Kernel responded with truncated message"); return -1; } - gehdr = mnl_nlmsg_get_payload(nlh); + gehdr = ynl_nlmsg_data(nlh); if (gehdr->cmd != rsp_cmd) return ynl_ntf_parse(ys, nlh); return 0; } -static int ynl_req_trampoline(const struct nlmsghdr *nlh, void *data) +static +int ynl_req_trampoline(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg) { - struct ynl_req_state *yrs = data; + struct ynl_req_state *yrs = (void *)yarg; int ret; ret = ynl_check_alien(yrs->yarg.ys, nlh, yrs->rsp_cmd); if (ret) - return ret < 0 ? MNL_CB_ERROR : MNL_CB_OK; + return ret < 0 ? YNL_PARSE_CB_ERROR : YNL_PARSE_CB_OK; return yrs->cb(nlh, &yrs->yarg); } @@ -837,43 +894,38 @@ static int ynl_req_trampoline(const struct nlmsghdr *nlh, void *data) int ynl_exec(struct ynl_sock *ys, struct nlmsghdr *req_nlh, struct ynl_req_state *yrs) { - ssize_t len; int err; - err = mnl_socket_sendto(ys->sock, req_nlh, req_nlh->nlmsg_len); + err = ynl_msg_end(ys, req_nlh); + if (err < 0) + return err; + + err = send(ys->socket, req_nlh, req_nlh->nlmsg_len, 0); if (err < 0) return err; do { - len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, - MNL_SOCKET_BUFFER_SIZE); - if (len < 0) - return len; - - err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid, - ynl_req_trampoline, yrs, - ynl_cb_array, NLMSG_MIN_TYPE); - if (err < 0) - return err; + err = ynl_sock_read_msgs(&yrs->yarg, ynl_req_trampoline); } while (err > 0); - return 0; + return err; } -static int ynl_dump_trampoline(const struct nlmsghdr *nlh, void *data) +static int +ynl_dump_trampoline(const struct nlmsghdr *nlh, struct ynl_parse_arg *data) { - struct ynl_dump_state *ds = data; + struct ynl_dump_state *ds = (void *)data; struct ynl_dump_list_type *obj; struct ynl_parse_arg yarg = {}; int ret; - ret = ynl_check_alien(ds->ys, nlh, ds->rsp_cmd); + ret = ynl_check_alien(ds->yarg.ys, nlh, ds->rsp_cmd); if (ret) - return ret < 0 ? MNL_CB_ERROR : MNL_CB_OK; + return ret < 0 ? YNL_PARSE_CB_ERROR : YNL_PARSE_CB_OK; obj = calloc(1, ds->alloc_sz); if (!obj) - return MNL_CB_ERROR; + return YNL_PARSE_CB_ERROR; if (!ds->first) ds->first = obj; @@ -881,8 +933,7 @@ static int ynl_dump_trampoline(const struct nlmsghdr *nlh, void *data) ds->last->next = obj; ds->last = obj; - yarg.ys = ds->ys; - yarg.rsp_policy = ds->rsp_policy; + yarg = ds->yarg; yarg.data = &obj->data; return ds->cb(nlh, &yarg); @@ -900,22 +951,18 @@ static void *ynl_dump_end(struct ynl_dump_state *ds) int ynl_exec_dump(struct ynl_sock *ys, struct nlmsghdr *req_nlh, struct ynl_dump_state *yds) { - ssize_t len; int err; - err = mnl_socket_sendto(ys->sock, req_nlh, req_nlh->nlmsg_len); + err = ynl_msg_end(ys, req_nlh); if (err < 0) return err; - do { - len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, - MNL_SOCKET_BUFFER_SIZE); - if (len < 0) - goto err_close_list; + err = send(ys->socket, req_nlh, req_nlh->nlmsg_len, 0); + if (err < 0) + return err; - err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid, - ynl_dump_trampoline, yds, - ynl_cb_array, NLMSG_MIN_TYPE); + do { + err = ynl_sock_read_msgs(&yds->yarg, ynl_dump_trampoline); if (err < 0) goto err_close_list; } while (err > 0); diff --git a/tools/net/ynl/lib/ynl.h b/tools/net/ynl/lib/ynl.h index ce77a6d76ce0..6cd570b283ea 100644 --- a/tools/net/ynl/lib/ynl.h +++ b/tools/net/ynl/lib/ynl.h @@ -12,6 +12,7 @@ enum ynl_error_code { YNL_ERROR_NONE = 0, __YNL_ERRNO_END = 4096, YNL_ERROR_INTERNAL, + YNL_ERROR_DUMP_INTER, YNL_ERROR_EXPECT_ACK, YNL_ERROR_EXPECT_MSG, YNL_ERROR_UNEXPECT_MSG, @@ -19,6 +20,8 @@ enum ynl_error_code { YNL_ERROR_ATTR_INVALID, YNL_ERROR_UNKNOWN_NTF, YNL_ERROR_INV_RESP, + YNL_ERROR_INPUT_INVALID, + YNL_ERROR_INPUT_TOO_BIG, }; /** @@ -58,7 +61,7 @@ struct ynl_sock { /* private: */ const struct ynl_family *family; - struct mnl_socket *sock; + int socket; __u32 seq; __u32 portid; __u16 family_id; @@ -73,7 +76,7 @@ struct ynl_sock { struct ynl_ntf_base_type **ntf_last_next; struct nlmsghdr *nlh; - struct ynl_policy_nest *req_policy; + const struct ynl_policy_nest *req_policy; unsigned char *tx_buf; unsigned char *rx_buf; unsigned char raw_buf[]; @@ -88,6 +91,18 @@ void ynl_sock_destroy(struct ynl_sock *ys); !ynl_dump_obj_is_last(iter); \ iter = ynl_dump_obj_next(iter)) +/** + * ynl_dump_empty() - does the dump have no entries + * @dump: pointer to the dump list, as returned by a dump call + * + * Check if the dump is empty, i.e. contains no objects. + * Dump calls return NULL on error, and terminator element if empty. + */ +static inline bool ynl_dump_empty(void *dump) +{ + return dump == (void *)YNL_LIST_END; +} + int ynl_subscribe(struct ynl_sock *ys, const char *grp_name); int ynl_socket_get_fd(struct ynl_sock *ys); int ynl_ntf_check(struct ynl_sock *ys); diff --git a/tools/net/ynl/lib/ynl.py b/tools/net/ynl/lib/ynl.py deleted file mode 100644 index 1e10512b2117..000000000000 --- a/tools/net/ynl/lib/ynl.py +++ /dev/null @@ -1,824 +0,0 @@ -# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause - -from collections import namedtuple -import functools -import os -import random -import socket -import struct -from struct import Struct -import yaml -import ipaddress -import uuid - -from .nlspec import SpecFamily - -# -# Generic Netlink code which should really be in some library, but I can't quickly find one. -# - - -class Netlink: - # Netlink socket - SOL_NETLINK = 270 - - NETLINK_ADD_MEMBERSHIP = 1 - NETLINK_CAP_ACK = 10 - NETLINK_EXT_ACK = 11 - NETLINK_GET_STRICT_CHK = 12 - - # Netlink message - NLMSG_ERROR = 2 - NLMSG_DONE = 3 - - NLM_F_REQUEST = 1 - NLM_F_ACK = 4 - NLM_F_ROOT = 0x100 - NLM_F_MATCH = 0x200 - - NLM_F_REPLACE = 0x100 - NLM_F_EXCL = 0x200 - NLM_F_CREATE = 0x400 - NLM_F_APPEND = 0x800 - - NLM_F_CAPPED = 0x100 - NLM_F_ACK_TLVS = 0x200 - - NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH - - NLA_F_NESTED = 0x8000 - NLA_F_NET_BYTEORDER = 0x4000 - - NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER - - # Genetlink defines - NETLINK_GENERIC = 16 - - GENL_ID_CTRL = 0x10 - - # nlctrl - CTRL_CMD_GETFAMILY = 3 - - CTRL_ATTR_FAMILY_ID = 1 - CTRL_ATTR_FAMILY_NAME = 2 - CTRL_ATTR_MAXATTR = 5 - CTRL_ATTR_MCAST_GROUPS = 7 - - CTRL_ATTR_MCAST_GRP_NAME = 1 - CTRL_ATTR_MCAST_GRP_ID = 2 - - # Extack types - NLMSGERR_ATTR_MSG = 1 - NLMSGERR_ATTR_OFFS = 2 - NLMSGERR_ATTR_COOKIE = 3 - NLMSGERR_ATTR_POLICY = 4 - NLMSGERR_ATTR_MISS_TYPE = 5 - NLMSGERR_ATTR_MISS_NEST = 6 - - -class NlError(Exception): - def __init__(self, nl_msg): - self.nl_msg = nl_msg - - def __str__(self): - return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}" - - -class NlAttr: - ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little']) - type_formats = { - 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")), - 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")), - 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")), - 's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")), - 'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")), - 's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")), - 'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")), - 's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q")) - } - - def __init__(self, raw, offset): - self._len, self._type = struct.unpack("HH", raw[offset : offset + 4]) - self.type = self._type & ~Netlink.NLA_TYPE_MASK - self.is_nest = self._type & Netlink.NLA_F_NESTED - self.payload_len = self._len - self.full_len = (self.payload_len + 3) & ~3 - self.raw = raw[offset + 4 : offset + self.payload_len] - - @classmethod - def get_format(cls, attr_type, byte_order=None): - format = cls.type_formats[attr_type] - if byte_order: - return format.big if byte_order == "big-endian" \ - else format.little - return format.native - - @classmethod - def formatted_string(cls, raw, display_hint): - if display_hint == 'mac': - formatted = ':'.join('%02x' % b for b in raw) - elif display_hint == 'hex': - formatted = bytes.hex(raw, ' ') - elif display_hint in [ 'ipv4', 'ipv6' ]: - formatted = format(ipaddress.ip_address(raw)) - elif display_hint == 'uuid': - formatted = str(uuid.UUID(bytes=raw)) - else: - formatted = raw - return formatted - - def as_scalar(self, attr_type, byte_order=None): - format = self.get_format(attr_type, byte_order) - return format.unpack(self.raw)[0] - - def as_auto_scalar(self, attr_type, byte_order=None): - if len(self.raw) != 4 and len(self.raw) != 8: - raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}") - real_type = attr_type[0] + str(len(self.raw) * 8) - format = self.get_format(real_type, byte_order) - return format.unpack(self.raw)[0] - - def as_strz(self): - return self.raw.decode('ascii')[:-1] - - def as_bin(self): - return self.raw - - def as_c_array(self, type): - format = self.get_format(type) - return [ x[0] for x in format.iter_unpack(self.raw) ] - - def as_struct(self, members): - value = dict() - offset = 0 - for m in members: - # TODO: handle non-scalar members - if m.type == 'binary': - decoded = self.raw[offset : offset + m['len']] - offset += m['len'] - elif m.type in NlAttr.type_formats: - format = self.get_format(m.type, m.byte_order) - [ decoded ] = format.unpack_from(self.raw, offset) - offset += format.size - if m.display_hint: - decoded = self.formatted_string(decoded, m.display_hint) - value[m.name] = decoded - return value - - def __repr__(self): - return f"[type:{self.type} len:{self._len}] {self.raw}" - - -class NlAttrs: - def __init__(self, msg, offset=0): - self.attrs = [] - - while offset < len(msg): - attr = NlAttr(msg, offset) - offset += attr.full_len - self.attrs.append(attr) - - def __iter__(self): - yield from self.attrs - - def __repr__(self): - msg = '' - for a in self.attrs: - if msg: - msg += '\n' - msg += repr(a) - return msg - - -class NlMsg: - def __init__(self, msg, offset, attr_space=None): - self.hdr = msg[offset : offset + 16] - - self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \ - struct.unpack("IHHII", self.hdr) - - self.raw = msg[offset + 16 : offset + self.nl_len] - - self.error = 0 - self.done = 0 - - extack_off = None - if self.nl_type == Netlink.NLMSG_ERROR: - self.error = struct.unpack("i", self.raw[0:4])[0] - self.done = 1 - extack_off = 20 - elif self.nl_type == Netlink.NLMSG_DONE: - self.done = 1 - extack_off = 4 - - self.extack = None - if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off: - self.extack = dict() - extack_attrs = NlAttrs(self.raw[extack_off:]) - for extack in extack_attrs: - if extack.type == Netlink.NLMSGERR_ATTR_MSG: - self.extack['msg'] = extack.as_strz() - elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE: - self.extack['miss-type'] = extack.as_scalar('u32') - elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST: - self.extack['miss-nest'] = extack.as_scalar('u32') - elif extack.type == Netlink.NLMSGERR_ATTR_OFFS: - self.extack['bad-attr-offs'] = extack.as_scalar('u32') - else: - if 'unknown' not in self.extack: - self.extack['unknown'] = [] - self.extack['unknown'].append(extack) - - if attr_space: - # We don't have the ability to parse nests yet, so only do global - if 'miss-type' in self.extack and 'miss-nest' not in self.extack: - miss_type = self.extack['miss-type'] - if miss_type in attr_space.attrs_by_val: - spec = attr_space.attrs_by_val[miss_type] - desc = spec['name'] - if 'doc' in spec: - desc += f" ({spec['doc']})" - self.extack['miss-type'] = desc - - def cmd(self): - return self.nl_type - - def __repr__(self): - msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n" - if self.error: - msg += '\terror: ' + str(self.error) - if self.extack: - msg += '\textack: ' + repr(self.extack) - return msg - - -class NlMsgs: - def __init__(self, data, attr_space=None): - self.msgs = [] - - offset = 0 - while offset < len(data): - msg = NlMsg(data, offset, attr_space=attr_space) - offset += msg.nl_len - self.msgs.append(msg) - - def __iter__(self): - yield from self.msgs - - -genl_family_name_to_id = None - - -def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None): - # we prepend length in _genl_msg_finalize() - if seq is None: - seq = random.randint(1, 1024) - nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) - genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0) - return nlmsg + genlmsg - - -def _genl_msg_finalize(msg): - return struct.pack("I", len(msg) + 4) + msg - - -def _genl_load_families(): - with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock: - sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) - - msg = _genl_msg(Netlink.GENL_ID_CTRL, - Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP, - Netlink.CTRL_CMD_GETFAMILY, 1) - msg = _genl_msg_finalize(msg) - - sock.send(msg, 0) - - global genl_family_name_to_id - genl_family_name_to_id = dict() - - while True: - reply = sock.recv(128 * 1024) - nms = NlMsgs(reply) - for nl_msg in nms: - if nl_msg.error: - print("Netlink error:", nl_msg.error) - return - if nl_msg.done: - return - - gm = GenlMsg(nl_msg) - fam = dict() - for attr in NlAttrs(gm.raw): - if attr.type == Netlink.CTRL_ATTR_FAMILY_ID: - fam['id'] = attr.as_scalar('u16') - elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME: - fam['name'] = attr.as_strz() - elif attr.type == Netlink.CTRL_ATTR_MAXATTR: - fam['maxattr'] = attr.as_scalar('u32') - elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS: - fam['mcast'] = dict() - for entry in NlAttrs(attr.raw): - mcast_name = None - mcast_id = None - for entry_attr in NlAttrs(entry.raw): - if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME: - mcast_name = entry_attr.as_strz() - elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID: - mcast_id = entry_attr.as_scalar('u32') - if mcast_name and mcast_id is not None: - fam['mcast'][mcast_name] = mcast_id - if 'name' in fam and 'id' in fam: - genl_family_name_to_id[fam['name']] = fam - - -class GenlMsg: - def __init__(self, nl_msg): - self.nl = nl_msg - self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0) - self.raw = nl_msg.raw[4:] - - def cmd(self): - return self.genl_cmd - - def __repr__(self): - msg = repr(self.nl) - msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n" - for a in self.raw_attrs: - msg += '\t\t' + repr(a) + '\n' - return msg - - -class NetlinkProtocol: - def __init__(self, family_name, proto_num): - self.family_name = family_name - self.proto_num = proto_num - - def _message(self, nl_type, nl_flags, seq=None): - if seq is None: - seq = random.randint(1, 1024) - nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) - return nlmsg - - def message(self, flags, command, version, seq=None): - return self._message(command, flags, seq) - - def _decode(self, nl_msg): - return nl_msg - - def decode(self, ynl, nl_msg): - msg = self._decode(nl_msg) - fixed_header_size = 0 - if ynl: - op = ynl.rsp_by_value[msg.cmd()] - fixed_header_size = ynl._fixed_header_size(op.fixed_header) - msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size) - return msg - - def get_mcast_id(self, mcast_name, mcast_groups): - if mcast_name not in mcast_groups: - raise Exception(f'Multicast group "{mcast_name}" not present in the spec') - return mcast_groups[mcast_name].value - - -class GenlProtocol(NetlinkProtocol): - def __init__(self, family_name): - super().__init__(family_name, Netlink.NETLINK_GENERIC) - - global genl_family_name_to_id - if genl_family_name_to_id is None: - _genl_load_families() - - self.genl_family = genl_family_name_to_id[family_name] - self.family_id = genl_family_name_to_id[family_name]['id'] - - def message(self, flags, command, version, seq=None): - nlmsg = self._message(self.family_id, flags, seq) - genlmsg = struct.pack("BBH", command, version, 0) - return nlmsg + genlmsg - - def _decode(self, nl_msg): - return GenlMsg(nl_msg) - - def get_mcast_id(self, mcast_name, mcast_groups): - if mcast_name not in self.genl_family['mcast']: - raise Exception(f'Multicast group "{mcast_name}" not present in the family') - return self.genl_family['mcast'][mcast_name] - - -# -# YNL implementation details. -# - - -class YnlFamily(SpecFamily): - def __init__(self, def_path, schema=None, process_unknown=False): - super().__init__(def_path, schema) - - self.include_raw = False - self.process_unknown = process_unknown - - try: - if self.proto == "netlink-raw": - self.nlproto = NetlinkProtocol(self.yaml['name'], - self.yaml['protonum']) - else: - self.nlproto = GenlProtocol(self.yaml['name']) - except KeyError: - raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") - - self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) - self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) - self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) - self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) - - self.async_msg_ids = set() - self.async_msg_queue = [] - - for msg in self.msgs.values(): - if msg.is_async: - self.async_msg_ids.add(msg.rsp_value) - - for op_name, op in self.ops.items(): - bound_f = functools.partial(self._op, op_name) - setattr(self, op.ident_name, bound_f) - - - def ntf_subscribe(self, mcast_name): - mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) - self.sock.bind((0, 0)) - self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, - mcast_id) - - def _add_attr(self, space, name, value): - try: - attr = self.attr_sets[space][name] - except KeyError: - raise Exception(f"Space '{space}' has no attribute '{name}'") - nl_type = attr.value - if attr["type"] == 'nest': - nl_type |= Netlink.NLA_F_NESTED - attr_payload = b'' - for subname, subvalue in value.items(): - attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue) - elif attr["type"] == 'flag': - attr_payload = b'' - elif attr["type"] == 'string': - attr_payload = str(value).encode('ascii') + b'\x00' - elif attr["type"] == 'binary': - if isinstance(value, bytes): - attr_payload = value - elif isinstance(value, str): - attr_payload = bytes.fromhex(value) - else: - raise Exception(f'Unknown type for binary attribute, value: {value}') - elif attr.is_auto_scalar: - scalar = int(value) - real_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') - format = NlAttr.get_format(real_type, attr.byte_order) - attr_payload = format.pack(int(value)) - elif attr['type'] in NlAttr.type_formats: - format = NlAttr.get_format(attr['type'], attr.byte_order) - attr_payload = format.pack(int(value)) - elif attr['type'] in "bitfield32": - attr_payload = struct.pack("II", int(value["value"]), int(value["selector"])) - else: - raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') - - pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) - return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad - - def _decode_enum(self, raw, attr_spec): - enum = self.consts[attr_spec['enum']] - if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): - i = 0 - value = set() - while raw: - if raw & 1: - value.add(enum.entries_by_val[i].name) - raw >>= 1 - i += 1 - else: - value = enum.entries_by_val[raw].name - return value - - def _decode_binary(self, attr, attr_spec): - if attr_spec.struct_name: - members = self.consts[attr_spec.struct_name] - decoded = attr.as_struct(members) - for m in members: - if m.enum: - decoded[m.name] = self._decode_enum(decoded[m.name], m) - elif attr_spec.sub_type: - decoded = attr.as_c_array(attr_spec.sub_type) - else: - decoded = attr.as_bin() - if attr_spec.display_hint: - decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint) - return decoded - - def _decode_array_nest(self, attr, attr_spec): - decoded = [] - offset = 0 - while offset < len(attr.raw): - item = NlAttr(attr.raw, offset) - offset += item.full_len - - subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) - decoded.append({ item.type: subattrs }) - return decoded - - def _decode_unknown(self, attr): - if attr.is_nest: - return self._decode(NlAttrs(attr.raw), None) - else: - return attr.as_bin() - - def _rsp_add(self, rsp, name, is_multi, decoded): - if is_multi == None: - if name in rsp and type(rsp[name]) is not list: - rsp[name] = [rsp[name]] - is_multi = True - else: - is_multi = False - - if not is_multi: - rsp[name] = decoded - elif name in rsp: - rsp[name].append(decoded) - else: - rsp[name] = [decoded] - - def _resolve_selector(self, attr_spec, vals): - sub_msg = attr_spec.sub_message - if sub_msg not in self.sub_msgs: - raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}") - sub_msg_spec = self.sub_msgs[sub_msg] - - selector = attr_spec.selector - if selector not in vals: - raise Exception(f"There is no value for {selector} to resolve '{attr_spec.name}'") - value = vals[selector] - if value not in sub_msg_spec.formats: - raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'") - - spec = sub_msg_spec.formats[value] - return spec - - def _decode_sub_msg(self, attr, attr_spec, rsp): - msg_format = self._resolve_selector(attr_spec, rsp) - decoded = {} - offset = 0 - if msg_format.fixed_header: - decoded.update(self._decode_fixed_header(attr, msg_format.fixed_header)); - offset = self._fixed_header_size(msg_format.fixed_header) - if msg_format.attr_set: - if msg_format.attr_set in self.attr_sets: - subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) - decoded.update(subdict) - else: - raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'") - return decoded - - def _decode(self, attrs, space): - if space: - attr_space = self.attr_sets[space] - rsp = dict() - for attr in attrs: - try: - attr_spec = attr_space.attrs_by_val[attr.type] - except (KeyError, UnboundLocalError): - if not self.process_unknown: - raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") - attr_name = f"UnknownAttr({attr.type})" - self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) - continue - - if attr_spec["type"] == 'nest': - subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes']) - decoded = subdict - elif attr_spec["type"] == 'string': - decoded = attr.as_strz() - elif attr_spec["type"] == 'binary': - decoded = self._decode_binary(attr, attr_spec) - elif attr_spec["type"] == 'flag': - decoded = True - elif attr_spec.is_auto_scalar: - decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) - elif attr_spec["type"] in NlAttr.type_formats: - decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) - if 'enum' in attr_spec: - decoded = self._decode_enum(decoded, attr_spec) - elif attr_spec["type"] == 'array-nest': - decoded = self._decode_array_nest(attr, attr_spec) - elif attr_spec["type"] == 'bitfield32': - value, selector = struct.unpack("II", attr.raw) - if 'enum' in attr_spec: - value = self._decode_enum(value, attr_spec) - selector = self._decode_enum(selector, attr_spec) - decoded = {"value": value, "selector": selector} - elif attr_spec["type"] == 'sub-message': - decoded = self._decode_sub_msg(attr, attr_spec, rsp) - else: - if not self.process_unknown: - raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') - decoded = self._decode_unknown(attr) - - self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) - - return rsp - - def _decode_extack_path(self, attrs, attr_set, offset, target): - for attr in attrs: - try: - attr_spec = attr_set.attrs_by_val[attr.type] - except KeyError: - raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") - if offset > target: - break - if offset == target: - return '.' + attr_spec.name - - if offset + attr.full_len <= target: - offset += attr.full_len - continue - if attr_spec['type'] != 'nest': - raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") - offset += 4 - subpath = self._decode_extack_path(NlAttrs(attr.raw), - self.attr_sets[attr_spec['nested-attributes']], - offset, target) - if subpath is None: - return None - return '.' + attr_spec.name + subpath - - return None - - def _decode_extack(self, request, op, extack): - if 'bad-attr-offs' not in extack: - return - - msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set)) - offset = 20 + self._fixed_header_size(op.fixed_header) - path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, - extack['bad-attr-offs']) - if path: - del extack['bad-attr-offs'] - extack['bad-attr'] = path - - def _fixed_header_size(self, name): - if name: - fixed_header_members = self.consts[name].members - size = 0 - for m in fixed_header_members: - if m.type in ['pad', 'binary']: - size += m.len - else: - format = NlAttr.get_format(m.type, m.byte_order) - size += format.size - return size - else: - return 0 - - def _decode_fixed_header(self, msg, name): - fixed_header_members = self.consts[name].members - fixed_header_attrs = dict() - offset = 0 - for m in fixed_header_members: - value = None - if m.type == 'pad': - offset += m.len - elif m.type == 'binary': - value = msg.raw[offset : offset + m.len] - offset += m.len - else: - format = NlAttr.get_format(m.type, m.byte_order) - [ value ] = format.unpack_from(msg.raw, offset) - offset += format.size - if value is not None: - if m.enum: - value = self._decode_enum(value, m) - fixed_header_attrs[m.name] = value - return fixed_header_attrs - - def handle_ntf(self, decoded): - msg = dict() - if self.include_raw: - msg['raw'] = decoded - op = self.rsp_by_value[decoded.cmd()] - attrs = self._decode(decoded.raw_attrs, op.attr_set.name) - if op.fixed_header: - attrs.update(self._decode_fixed_header(decoded, op.fixed_header)) - - msg['name'] = op['name'] - msg['msg'] = attrs - self.async_msg_queue.append(msg) - - def check_ntf(self): - while True: - try: - reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT) - except BlockingIOError: - return - - nms = NlMsgs(reply) - for nl_msg in nms: - if nl_msg.error: - print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) - print(nl_msg) - continue - if nl_msg.done: - print("Netlink done while checking for ntf!?") - continue - - decoded = self.nlproto.decode(self, nl_msg) - if decoded.cmd() not in self.async_msg_ids: - print("Unexpected msg id done while checking for ntf", decoded) - continue - - self.handle_ntf(decoded) - - def operation_do_attributes(self, name): - """ - For a given operation name, find and return a supported - set of attributes (as a dict). - """ - op = self.find_operation(name) - if not op: - return None - - return op['do']['request']['attributes'].copy() - - def _op(self, method, vals, flags=None, dump=False): - op = self.ops[method] - - nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK - for flag in flags or []: - nl_flags |= flag - if dump: - nl_flags |= Netlink.NLM_F_DUMP - - req_seq = random.randint(1024, 65535) - msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) - fixed_header_members = [] - if op.fixed_header: - fixed_header_members = self.consts[op.fixed_header].members - for m in fixed_header_members: - value = vals.pop(m.name) if m.name in vals else 0 - if m.type == 'pad': - msg += bytearray(m.len) - elif m.type == 'binary': - msg += bytes.fromhex(value) - else: - format = NlAttr.get_format(m.type, m.byte_order) - msg += format.pack(value) - for name, value in vals.items(): - msg += self._add_attr(op.attr_set.name, name, value) - msg = _genl_msg_finalize(msg) - - self.sock.send(msg, 0) - - done = False - rsp = [] - while not done: - reply = self.sock.recv(128 * 1024) - nms = NlMsgs(reply, attr_space=op.attr_set) - for nl_msg in nms: - if nl_msg.extack: - self._decode_extack(msg, op, nl_msg.extack) - - if nl_msg.error: - raise NlError(nl_msg) - if nl_msg.done: - if nl_msg.extack: - print("Netlink warning:") - print(nl_msg) - done = True - break - - decoded = self.nlproto.decode(self, nl_msg) - - # Check if this is a reply to our request - if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value: - if decoded.cmd() in self.async_msg_ids: - self.handle_ntf(decoded) - continue - else: - print('Unexpected message: ' + repr(decoded)) - continue - - rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) - if op.fixed_header: - rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header)) - rsp.append(rsp_msg) - - if not rsp: - return None - if not dump and len(rsp) == 1: - return rsp[0] - return rsp - - def do(self, method, vals, flags=None): - return self._op(method, vals, flags) - - def dump(self, method, vals): - return self._op(method, vals, [], dump=True) |