models
TOC
- Attributes:
- 🅰 NodeClassType
- 🅰 NodeParams - type alias for node parameters.
- 🅰 NodeInstance - type alias for an instantiated module.
- 🅰 NoCallSkipFlag - return type of a node that skips the call and returns itself.
- 🅰 SpecialFlag - type of the special flag prefixes.
- 🅰 REUSE_FLAG - flag for shared module, which will be built once and cached.
- 🅰 INTER_FLAG - flag for intermediate module, which will be built from scratch if needed.
- 🅰 CLASS_FLAG - flag for using the module class itself, instead of its instance.
- 🅰 REFER_FLAG - flag for referring to a value from the top level of the config.
- 🅰 OTHER_FLAG - default flag.
- 🅰 FLAG_PATTERN - regex matching a leading special flag.
- 🅰 DO_NOT_CALL_KEY - key marking a node that will not be called.
- 🅰 IS_PARSING - flag for the parsing phase.
- 🅰 SPECIAL_FLAGS - registered special flags for module nodes.
- 🅰 HOOK_FLAGS - hook flags.
- 🅰 ConfigNode - ConfigNode type in parsing phase
- 🅰 NodeType - Type of ModuleNode
- 🅰 _dispatch_module_node - maps each special flag to its node type.
- 🅰 _dispatch_argument_hook - maps each hook flag to its hook type.
- Functions:
- 🅵 silent - Disables logging of build messages.
- 🅵 _is_special - Determine if the given string begin with target special flag.
- 🅵 _str_to_target - Imports a module or retrieves a class/function from a module
- 🅵 register_special_flag - Register a new special flag for module nodes.
- 🅵 register_argument_hook - Register a new argument hook.
- Classes:
- 🅲 ModuleNode - A base class representing `LazyConfig`, which is similar to `detectron2.config.lazy.LazyCall`.
- 🅲 InterNode - Intermediate module node. More details see `config.overview`.
- 🅲 ConfigHookNode - Wrapper for `Hook` or `ConfigArgumentHook`.
- 🅲 ReusedNode - A subclass of InterNode representing a reused module node.
- 🅲 ClassNode - `ClassNode` returns the wrapped class, function or module itself instead of calling them.
- 🅲 ConfigArgumentHook - An abstract base class for configuration argument hooks.
- 🅲 GetAttr - A subclass of ConfigArgumentHook for getting attributes.
- 🅲 VariableReference - A subclass of ClassNode for variable references.
- 🅲 ModuleWrapper
Attributes
🅰 NodeClassType
NodeClassType = Type[Any]
🅰 NodeParams
NodeParams = Dict[str, Any]  # type alias for node parameters.
🅰 NodeInstance
NodeInstance = object  # type alias for an instantiated module.
🅰 NoCallSkipFlag
NoCallSkipFlag = Self  # returned when a node skips the call and returns itself.
🅰 SpecialFlag
SpecialFlag = Literal["@", "!", "$", "&", ""]  # type of the special flag prefixes.
🅰 REUSE_FLAG
REUSE_FLAG: Literal["@"] = "@"  # flag for shared module, which will be built once and cached.
🅰 INTER_FLAG
INTER_FLAG: Literal["!"] = "!"  # flag for intermediate module, which will be built from scratch if needed.
🅰 CLASS_FLAG
CLASS_FLAG: Literal["$"] = "$"  # flag for using the module class itself, instead of its instance.
🅰 REFER_FLAG
REFER_FLAG: Literal["&"] = "&"  # flag for referring to a value from the top level of the config.
🅰 OTHER_FLAG
OTHER_FLAG: Literal[""] = ""  # default flag.
🅰 FLAG_PATTERN
FLAG_PATTERN = re.compile("^([@!$&])(.*)$")  # regex matching a leading special flag.
🅰 DO_NOT_CALL_KEY
DO_NOT_CALL_KEY = """__no_call__"""  # key marking a node that will not be called.
🅰 IS_PARSING
IS_PARSING = True  # flag for the parsing phase.
🅰 SPECIAL_FLAGS
SPECIAL_FLAGS = [OTHER_FLAG, INTER_FLAG, REUSE_FLAG, CLASS_FLAG, REFER_FLAG]  # registered special flags for module nodes.
🅰 HOOK_FLAGS
HOOK_FLAGS = ["@", "."] #hook flags.
🅰 ConfigNode
ConfigNode = Union[ModuleNode, ConfigArgumentHook] #ConfigNode type in parsing phase
🅰 NodeType
NodeType = Type[ModuleNode] #Type of ModuleNode
🅰 _dispatch_module_node
_dispatch_module_node: dict[SpecialFlag, NodeType] = {
    OTHER_FLAG: ModuleNode,
    REUSE_FLAG: ReusedNode,
    INTER_FLAG: InterNode,
    CLASS_FLAG: ClassNode,
    REFER_FLAG: VariableReference,
}  # type: ignore
🅰 _dispatch_argument_hook
_dispatch_argument_hook: dict[str, Type[ConfigArgumentHook]] = {
    "@": ConfigArgumentHook,
    ".": GetAttr,
}  # type: ignore
Functions
🅵 silent
def silent() -> None:
    workspace.excore_log_build_message = False
Disables logging of build messages.
🅵 _is_special
def _is_special(k: str) -> tuple[str, SpecialFlag]:
    match = FLAG_PATTERN.match(k)
    if match:
        logger.ex(f"Find match `{match}`.")
        return match.group(2), match.group(1)
    logger.ex("No Match.")
    return k, ""
Determines whether the given string begins with a special flag.
`@` denotes a reused module, which is built only once and then cached. `!` denotes an intermediate module, which is built from scratch if needed. `$` denotes using the module class itself instead of its instance. `&` denotes referring to a value at the top level of the config. Other user-defined special flags can be registered with `register_special_flag`. For all default flags, see `SPECIAL_FLAGS`.
Parameters:
- k (str): The input string to check.
Returns:
- tuple[str, SpecialFlag]: The string with its leading flag removed and the matched flag, or the unchanged string and `""` if no flag matches.
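The flag split performed by `_is_special` can be illustrated with a standalone sketch. `build_pattern` and `split_flag` are hypothetical helpers, not excore API. Unlike the original `FLAG_PATTERN` construction, this sketch escapes every flag with `re.escape`. That is an assumption which only matters if a registered flag is a regex metacharacter such as `]` or `^`.

```python
from __future__ import annotations

import re

# Default flags as in SPECIAL_FLAGS; the empty default flag is skipped because
# it cannot be part of a character class.
FLAGS = ["", "!", "@", "$", "&"]


def build_pattern(flags: list[str]) -> re.Pattern[str]:
    # re.escape keeps flags such as "^" or "]" literal inside the character class.
    body = "".join(re.escape(flag) for flag in flags if flag)
    return re.compile(f"^([{body}])(.*)$")


FLAG_RE = build_pattern(FLAGS)


def split_flag(key: str) -> tuple[str, str]:
    """Return (name, flag); flag is "" when the key has no special prefix."""
    match = FLAG_RE.match(key)
    if match:
        return match.group(2), match.group(1)
    return key, ""


print(split_flag("@Backbone"))  # ('Backbone', '@')
print(split_flag("Backbone"))   # ('Backbone', '')
```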
🅵 _str_to_target
def _str_to_target(
    module_name: str,
) -> ModuleType | NodeClassType | FunctionType:
    module_names = module_name.split(".")
    if len(module_names) == 1:
        return importlib.import_module(module_names[0])
    target_name = module_names.pop(-1)
    try:
        module = importlib.import_module(".".join(module_names))
    except ModuleNotFoundError as exc:
        raise StrToClassError(
            f"Cannot import such module: `{'.'.join(module_names)}`"
        ) from exc
    try:
        module = getattr(module, target_name)
    except AttributeError as exc:
        raise StrToClassError(
            f"Cannot find such module `{target_name}` from `{'.'.join(module_names)}`"
        ) from exc
    return module
Imports a module or retrieves a class/function from a module based on the provided module name.
Parameters:
- module_name (str): The name of the module or the module path with the target class/function.
Returns:
- ModuleType | NodeClassType | FunctionType: The imported module, class or function.
Raises:
- StrToClassError: If the module or target cannot be imported or found.
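The same dotted-path resolution can be reproduced with only the standard library. The following is a minimal sketch: `resolve_dotted` is a hypothetical name, and it raises the built-in `ImportError` in place of excore's `StrToClassError`. Like the original, it resolves only one attribute below an importable module.

```python
from __future__ import annotations

import importlib
from typing import Any


def resolve_dotted(path: str) -> Any:
    """Import `path` as a module, or import its parent and return the last attribute."""
    parts = path.split(".")
    if len(parts) == 1:
        # A bare name is treated as a top-level module.
        return importlib.import_module(parts[0])
    parent, target = ".".join(parts[:-1]), parts[-1]
    try:
        module = importlib.import_module(parent)
    except ModuleNotFoundError as exc:
        raise ImportError(f"Cannot import module `{parent}`") from exc
    try:
        return getattr(module, target)
    except AttributeError as exc:
        raise ImportError(f"Cannot find `{target}` in `{parent}`") from exc


dumps = resolve_dotted("json.dumps")
print(dumps({"a": 1}))  # {"a": 1}
```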
🅵 register_special_flag
def register_special_flag(
    flag: str, node_type: NodeType, force: bool = False
) -> None:
    global FLAG_PATTERN
    if not force and flag in SPECIAL_FLAGS:
        raise ValueError(f"Special flag `{flag}` already exists.")
    SPECIAL_FLAGS.append(flag)
    FLAG_PATTERN = re.compile(f"^([{''.join(SPECIAL_FLAGS)}])(.*)$")
    _dispatch_module_node[flag] = node_type
    logger.ex(
        f"Register new module node `{node_type}` with special flag `{flag}`."
    )
Register a new special flag for module nodes.
Parameters:
- flag (str): The special flag to register.
- node_type (NodeType): The type of node associated with the flag.
- force (bool): Whether to force registration if the flag already exists. Defaults to False.
Raises:
- ValueError: If the flag already exists and force is False.
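The registration flow above (append the flag, rebuild the pattern, extend the dispatch table) is sketched below in a self-contained form. It stores node types as plain strings for brevity. `register_flag` and `"LoggedNode"` are hypothetical names. The sketch also deviates from the original in two ways: it avoids appending a duplicate when `force=True`, and it escapes flags with `re.escape`.

```python
from __future__ import annotations

import re

SPECIAL_FLAGS: list[str] = ["", "!", "@", "$", "&"]
DISPATCH: dict[str, str] = {
    "": "ModuleNode",
    "!": "InterNode",
    "@": "ReusedNode",
    "$": "ClassNode",
    "&": "VariableReference",
}
FLAG_PATTERN = re.compile("^([@!$&])(.*)$")


def register_flag(flag: str, node_type: str, force: bool = False) -> None:
    """Register `flag` for `node_type` (a class name here, for simplicity)."""
    global FLAG_PATTERN
    if not force and flag in SPECIAL_FLAGS:
        raise ValueError(f"Special flag `{flag}` already exists.")
    if flag not in SPECIAL_FLAGS:
        SPECIAL_FLAGS.append(flag)
    # Rebuild the pattern so keys carrying the new prefix are recognised.
    body = "".join(re.escape(f) for f in SPECIAL_FLAGS if f)
    FLAG_PATTERN = re.compile(f"^([{body}])(.*)$")
    DISPATCH[flag] = node_type


register_flag("~", "LoggedNode")
match = FLAG_PATTERN.match("~Backbone")
print(match.group(1), match.group(2), DISPATCH["~"])  # ~ Backbone LoggedNode
```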
🅵 register_argument_hook
def register_argument_hook(
    flag: str, node_type: Type[ConfigArgumentHook], force: bool = False
) -> None:
    if not force and flag in HOOK_FLAGS:
        raise ValueError(f"Special flag `{flag}` already exists.")
    HOOK_FLAGS.append(flag)
    _dispatch_argument_hook[flag] = node_type
    logger.ex(
        f"Register new hook node `{node_type}` with special flag `{flag}`."
    )
Register a new argument hook.
Parameters:
- flag (str): The flag associated with the hook.
- node_type (Type[ConfigArgumentHook]): The type of hook to register.
- force (bool): Whether to force registration if the flag already exists. Defaults to False.
Raises:
- ValueError: If the flag already exists and force is False.
Classes
🅲 ModuleNode
@dataclass
class ModuleNode(dict):
    target: Any = None
    _no_call: bool = field(default=False, repr=False)
    priority: int = field(default=0, repr=False)
A base class representing `LazyConfig`, which is similar to `detectron2.config.lazy.LazyCall`.
Wraps a class, function or Python module together with its parameters until you want to call it.
Attributes:
- target (Any): The class or module associated with the node.
- _no_call (bool): Flag to indicate if the node should not be called when you actually call it. Usually used with function so in the config parsing phase the `target` will not be called. Defaults to False.
- priority (int): Priority level of the node, used in parsing phase.
Examples:
# Store class
node = ModuleNode(MyClass).add(a=1, b=2)
instance = node()
# Store function
node = ModuleNode(my_func).add(a=1, b=2)
result = node()
# Store module
node = ModuleNode(my_module).add(a=1, b=2)
result = node() # module itself
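The examples above can be made runnable with a minimal sketch of the lazy-call idea behind `ModuleNode`, written without excore so it runs standalone. `LazyNode` and `Linear` are hypothetical. Like `ModuleNode`, the sketch is a dataclass that subclasses `dict`, but it omits validation, `_no_call`, priorities and nested-node resolution.

```python
from __future__ import annotations

from dataclasses import dataclass
from types import ModuleType
from typing import Any


@dataclass
class LazyNode(dict):
    """Stores a target and its keyword arguments; nothing runs until it is called."""

    target: Any = None

    def add(self, **params: Any) -> LazyNode:
        self.update(params)
        return self

    def __call__(self, **params: Any) -> Any:
        self.update(params)
        if isinstance(self.target, ModuleType):
            return self.target  # a module is returned as-is, never called
        return self.target(**self)  # dict items become keyword arguments


class Linear:
    def __init__(self, a: int, b: int) -> None:
        self.a, self.b = a, b


node = LazyNode(Linear).add(a=1, b=2)  # nothing is built yet
layer = node()                         # Linear(a=1, b=2) is created here
print(layer.a, layer.b)                # 1 2
```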
🅼 _update_params
def _update_params(self, **params: NodeParams) -> None:
    return_params = {}
    for k, v in self.items():
        if isinstance(v, (ModuleWrapper, ModuleNode)):
            v = v()
        return_params[k] = v
    self.update(params)
    self.update(return_params)
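Updates the parameters of the node. Any existing parameter that is a `ModuleNode` or `ModuleWrapper` is called first and replaced by its result. Because the existing parameters are written back last, they take precedence over newly passed ones with the same key.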
Parameters:
- **params (NodeParams): The parameters to update.
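The resolution order of `_update_params` can be made concrete with a standalone sketch. It uses a hypothetical `Node` class that copies the same three steps: build nested nodes, merge the new parameters, then write the built values back. The last step is why an existing key wins over a newly passed one. `make_backbone` and `make_model` are illustrative functions.

```python
from __future__ import annotations

from typing import Any, Callable


class Node(dict):
    """Hypothetical minimal node: a target plus kwargs stored in the dict itself."""

    def __init__(self, target: Callable[..., Any], **params: Any) -> None:
        super().__init__(params)
        self.target = target

    def _update_params(self, **params: Any) -> None:
        # Build nested nodes first, mirroring ModuleNode._update_params.
        built = {k: (v() if isinstance(v, Node) else v) for k, v in self.items()}
        self.update(params)
        # Existing (now built) values are written last, so they win on collisions.
        self.update(built)

    def __call__(self, **params: Any) -> Any:
        self._update_params(**params)
        return self.target(**self)


def make_backbone(depth: int) -> str:
    return f"backbone-{depth}"


def make_model(backbone: str, head: str = "fc") -> tuple[str, str]:
    return backbone, head


model = Node(make_model, backbone=Node(make_backbone, depth=50))
print(model(head="mlp"))  # ('backbone-50', 'mlp')
```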
🅼 name
@property
def name(self) -> str:
    return self.target.__name__
Property to get the name of the associated class or module.
Returns:
- str: The name of the class or module.
🅼 add
def add(self, **params: NodeParams) -> Self:
    self.update(params)
    return self
Adds parameters to the node.
Parameters:
- **params: The parameters to add.
Returns:
- Self: The updated node.
🅼 _instantiate
def _instantiate(self) -> NodeInstance:
    try:
        if ismodule(self.target):
            return self.target
        module = self.target(**self)
    except Exception as exc:
        raise ModuleBuildError(
            f"Instantiate Error with module {self.target} and arguments {self.items()}"
        ) from exc
    if workspace.excore_log_build_message:
        logger.success(
            f"Successfully instantiated: {self.target.__name__} with arguments {self.items()}"
        )
    return module
Instantiates the module, handling errors.
Returns:
- NodeInstance: The instantiated module.
Raises:
- ModuleBuildError: If instantiation fails.
🅼 __call__
def __call__(self, **params: NodeParams) -> NoCallSkipFlag | NodeInstance:
    # During parsing, a node marked with `_no_call` is returned untouched.
    if IS_PARSING and self._no_call:
        return self
    self._update_params(**params)
    self.validate()
    module = self._instantiate()
    return module
Calls the node: updates its parameters, validates them and instantiates the target.
Parameters:
- **params: The parameters for instantiation.
Returns:
- NoCallSkipFlag | NodeInstance: The instantiated module or the node itself if _no_call is True.
🅼 __lshift__
def __lshift__(self, params: NodeParams) -> Self:
    if not isinstance(params, dict):
        raise TypeError(f"Expect type is dict, but got {type(params)}")
    self.update(params)
    return self
Updates the node with new parameters.
Parameters:
- params (NodeParams): The parameters to update.
Returns:
- Self: The updated node.
Raises:
- TypeError: If the provided parameters are not a dictionary.
Examples:
node << dict()
🅼 __rshift__
def __rshift__(self, __other: ModuleNode) -> Self:
    if not isinstance(__other, ModuleNode):
        raise TypeError(f"Expect type is `ModuleNode`, but got {type(__other)}")
    __other.update(self)
    return self
Merges the parameters of the current node into `__other`.
Parameters:
- __other (ModuleNode): The node that receives the parameters.
Returns:
- Self: The current node, unchanged.
Raises:
- TypeError: If the provided other node is not a ModuleNode.
Examples:
node >> other
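The direction of the two operators is easy to mix up. `<<` pulls a dict into the left node, while `>>` pushes the left node's parameters into the right node and still returns the left node. The sketch below uses a hypothetical `Params` dict subclass that copies only that operator logic from `ModuleNode`.

```python
from __future__ import annotations

from typing import Any


class Params(dict):
    """Hypothetical dict with ModuleNode-style `<<` and `>>` operators."""

    def __lshift__(self, params: dict[str, Any]) -> Params:
        if not isinstance(params, dict):
            raise TypeError(f"Expect type is dict, but got {type(params)}")
        self.update(params)  # pull parameters into this node
        return self

    def __rshift__(self, other: Params) -> Params:
        if not isinstance(other, Params):
            raise TypeError(f"Expect type is `Params`, but got {type(other)}")
        other.update(self)  # push this node's parameters into `other`
        return self  # returns the left operand, not `other`


a = Params(lr=0.1)
b = Params(momentum=0.9)
a << {"wd": 1e-4}
result = a >> b
print(dict(b))      # {'momentum': 0.9, 'lr': 0.1, 'wd': 0.0001}
print(result is a)  # True
```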
🅼 __excore_check_target_type__
@classmethod
def __excore_check_target_type__(cls, target_type: type[ModuleNode]) -> bool:
    return False
Checks whether the target type conflicts with this node type.
Used in the config parsing phase.
Parameters:
- target_type (type[ModuleNode]): The target type to check.
Returns:
- bool: False, as this is a base class method.
🅼 __excore_parse__
@classmethod
def __excore_parse__(
    cls, config: ConfigDict, **locals: dict[str, Any]
) -> ModuleNode | None:
    return None
User-defined parsing logic. Disabled by default.
Parameters:
- config (ConfigDict): The configuration to parse.
- **locals (dict[str, Any]): Additional local variables for parsing.
Returns:
- None | ModuleNode: The parsed node or None.
🅼 from_str
@classmethod
def from_str(
    cls, str_target: str, params: NodeParams | None = None
) -> ModuleNode:
    node = cls(_str_to_target(str_target))
    if params:
        node.update(params)
    if node.pop(DO_NOT_CALL_KEY, False):
        node._no_call = True
    return node
Creates a node from a string target.
Parameters:
- str_target (str): The string target representing the module or class.
- params (NodeParams | None): The parameters for the node. Defaults to None.
Returns:
- ModuleNode: The created node.
Examples:
node = ModuleNode.from_str("package.module.class", dict(param1=value1))
The `str_target` must be registered in the registry. For more details, see `Registry`.
🅼 from_base_name
@classmethod
def from_base_name(
    cls, base: str, name: str, params: NodeParams | None = None
) -> ModuleNode:
    try:
        cls_name = Registry.get_registry(base)[name]
    except KeyError as exc:
        raise ModuleBuildError(
            f"Failed to find the registered module `{name}` with base registry `{base}`"
        ) from exc
    return cls.from_str(cls_name, params)
Creates a node from a base registry and name.
Parameters:
- base (str): The base registry.
- name (str): The name of the module or class.
- params (NodeParams | None): The parameters for the node. Defaults to None.
Returns:
- ModuleNode: The created node.
Raises:
- ModuleBuildError: If the module cannot be found in the registry.
Examples:
>>> node = ModuleNode.from_base_name("Module", "ClassName", dict(param1=value1))
🅼 from_node
@classmethod
def from_node(cls, _other: ModuleNode) -> ModuleNode:
    if _other.__class__.__name__ == cls.__name__:
        return _other
    node = cls(_other.target) << _other
    node._no_call = _other._no_call
    return node
Creates a new ModuleNode instance from another ModuleNode instance.
Parameters:
- _other (ModuleNode): The other ModuleNode instance to create from.
Returns:
- ModuleNode: A new ModuleNode instance or the original if they are of the same class.
Examples:
node = ModuleNode.from_node(other_node)
🅼 _inspect_params
@staticmethod
def _inspect_params(cls: type) -> list[inspect.Parameter]:
    signature = inspect.signature(cls.__init__ if isclass(cls) else cls)
    params = list(signature.parameters.values())
    if isclass(cls):
        params = params[1:]
    return params
Retrieves the inspect parameter objects of a class or function.
Parameters:
- cls (type): The class or function to inspect.
Returns:
- list[inspect.Parameter]: A list of inspect.Parameter objects.
🅼 validate
def validate(self) -> None:
    if not workspace.excore_validate:
        return
    if ismodule(self.target):
        return
    missing = []
    defaults = []
    params = ModuleNode._inspect_params(self.target)
    for param in params:
        if (
            param.default == param.empty
            and param.kind
            not in [Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD]
            and param.name not in self
        ):
            missing.append(param.name)
        else:
            defaults.append(param.name)
    message = f"Validating `{self.target.__name__}` , finding missing parameters: `{missing}` without default values."
    if not workspace.excore_manual_set and missing:
        raise ModuleValidateError(message)
    if missing:
        logger.info(message)
        for param_name in missing:
            logger.info(f"Input value of parameter `{param_name}`:")
            value = input()
            self[param_name] = DictAction._parse_iterable(value)
Validate the parameters of the ModuleNode instance.
This method checks whether all required parameters are provided. If validation is globally disabled or the target is a Python module, the method returns immediately.
If any required parameters are missing and manual setting is not allowed, a ModuleValidateError is raised.
If missing parameters are found and manual setting is allowed, the user is prompted to provide values for them. The input values are parsed into `int`, `str`, `list`, `tuple` or `dict`. For more details, see `DictAction._parse_iterable`.
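The core of this check can be reproduced with `inspect.signature` alone. The following sketch covers only the detection of missing parameters and leaves out the workspace switches and the interactive prompt. `required_params`, `missing_params` and `Optimizer` are hypothetical names. As in `_inspect_params`, a class is inspected through its `__init__` with `self` dropped.

```python
from __future__ import annotations

import inspect
from inspect import Parameter
from typing import Any, Callable


def required_params(target: Callable[..., Any]) -> list[str]:
    """Names of parameters without defaults, skipping *args/**kwargs and `self`."""
    func = target.__init__ if inspect.isclass(target) else target
    params = list(inspect.signature(func).parameters.values())
    if inspect.isclass(target):
        params = params[1:]  # drop `self`
    return [
        p.name
        for p in params
        if p.default is Parameter.empty
        and p.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
    ]


def missing_params(target: Callable[..., Any], given: dict[str, Any]) -> list[str]:
    return [name for name in required_params(target) if name not in given]


class Optimizer:
    def __init__(self, params, lr, momentum=0.0, **extra):
        pass


print(missing_params(Optimizer, {"params": []}))  # ['lr']
```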
🅲 InterNode
class InterNode(ModuleNode):
    priority: int = 2
Intermediate module node. More details see `config.overview`.
Attributes:
- priority (int): Priority level set to 2.
🅼 __excore_check_target_type__
@classmethod
def __excore_check_target_type__(cls, target_type: type[ModuleNode]) -> bool:
    return target_type is ReusedNode
Checks if the target type is ReusedNode.
Parameters:
- target_type (type[ModuleNode]): The target type to check.
Returns:
- bool: True if the target type is ReusedNode, otherwise False.
The same `ModuleName` must not refer to both `ReusedNode` and `InterNode`.
🅲 ConfigHookNode
class ConfigHookNode(ModuleNode):
    priority: int = 1
Wrapper for `Hook` or `ConfigArgumentHook`.
Attributes:
- priority (int): Priority level set to 1.
🅼 validate
def validate(self) -> None:
    if "node" in self:
        raise ModuleValidateError(
            f"Parameter `node:{self['node']}` should not exist in `ConfigHookNode`."
        )
    super().validate()
Validates the node, ensuring the `node` parameter is not present, because `node` should be passed in the config parsing phase rather than in the config definition.
Raises:
- ModuleValidateError: If 'node' parameter is found.
🅼 __call__
@overload
def __call__(
    self, **params: NodeParams
) -> NodeInstance | Hook | ConfigArgumentHook: ...
🅼 __call__
@overload
def __call__(self, **params: dict[str, ModuleNode]) -> ConfigHookNode: ...
🅼 __call__
def __call__(
    self, **params: NodeParams
) -> NodeInstance | Hook | ConfigArgumentHook:
    self._update_params(**params)
    return self._instantiate()
Calls the node to instantiate the module.
Parameters:
- **params: The parameters for instantiation.
Returns:
- NodeInstance | Hook | ConfigArgumentHook: The instantiated module or hook.
🅲 ReusedNode
class ReusedNode(InterNode):
    priority: int = 3
A subclass of InterNode representing a reused module node.
Attributes:
- priority (int): Priority level set to 3.
🅼 __call__
@CacheOut()
def __call__(self, **params: NodeParams) -> NodeInstance | NoCallSkipFlag:
    return super().__call__(**params)
Calls the node to instantiate the module, caching the result (see `CacheOut`).
Parameters:
- **params: The additional parameters for instantiation.
Returns:
- NodeInstance | NoCallSkipFlag: The instantiated module or the node itself if _no_call is True.
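The "build once, then reuse" behaviour of a reused node can be illustrated with a per-instance cache. This is a sketch only. `cache_once` is a hypothetical stand-in for `CacheOut`, whose exact semantics are not documented here, and `SharedNode` is illustrative. In this sketch, arguments passed after the first call are ignored.

```python
from __future__ import annotations

import functools
from typing import Any, Callable


def cache_once(method: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical stand-in for CacheOut: build once per instance, then reuse."""
    attr = f"_cached_{method.__name__}"

    @functools.wraps(method)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        if not hasattr(self, attr):
            setattr(self, attr, method(self, *args, **kwargs))
        return getattr(self, attr)

    return wrapper


class SharedNode:
    def __init__(self, target: Callable[[], Any]) -> None:
        self.target = target

    @cache_once
    def __call__(self) -> Any:
        return self.target()


backbone = SharedNode(object)
print(backbone() is backbone())  # True: the same instance is reused
```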
🅼 __excore_check_target_type__
@classmethod
def __excore_check_target_type__(cls, target_type: NodeType) -> bool:
    return target_type is InterNode
Checks if the target type is InterNode.
The same `ModuleName` must not refer to both `ReusedNode` and `InterNode`.
Parameters:
- target_type (NodeType): The target type to check.
Returns:
- bool: True if the target type is InterNode, otherwise False.
🅲 ClassNode
class ClassNode(ModuleNode):
    priority: int = 1
`ClassNode` returns the wrapped class, function or module itself instead of calling them.
Attributes:
- priority (int): Priority level set to 1.
🅼 validate
def validate(self) -> None:
    return
Does nothing, since a class node should not have any parameters.
🅼 __call__
def __call__(self) -> NodeClassType | FunctionType | ModuleType:
    return self.target
Returns the class, function or module itself.
Returns:
- NodeClassType | FunctionType | ModuleType: The class, function or module.
🅲 ConfigArgumentHook
class ConfigArgumentHook(ABC):
    flag: str = "@"
An abstract base class for configuration argument hooks.
Attributes:
- flag (str): The flag associated with the hook.
- node (Callable): The node associated with the hook.
- enabled (bool): Whether to apply the hook.
- name (str): The name of the wrapped node.
- _is_initialized (bool): Flag to check if the hook is initialized.
🅼 __init__
def __init__(self, node: Callable, enabled: bool = True) -> None:
    self.node = node
    self.enabled = enabled
    if not hasattr(node, "name"):
        raise ValueError("The `node` must have name attribute.")
    self.name = node.name
    self._is_initialized = True
Initializes the hook with a node and enabled status.
Parameters:
- node (Callable): The node associated with the hook.
- enabled (bool): Whether to apply the hook. Defaults to True.
Raises:
- ValueError: If the node does not have a name attribute.
🅼 hook
@abstractmethod
def hook(self, **kwargs: Any) -> Any:
    raise NotImplementedError(
        f"`{self.__class__.__name__}` does not implement `hook` method."
    )
Abstract method to implement the hook logic.
Parameters:
- **kwargs: The keyword arguments for the hook.
Returns:
- Any: The result of the hook.
Raises:
- NotImplementedError: If the method is not implemented by a subclass.
🅼 __call__
@final
def __call__(self, **kwargs: Any) -> Any:
    if not getattr(self, "_is_initialized", False):
        raise CoreConfigSupportError(
            f"Call super().__init__(node) in class `{self.__class__.__name__}`"
        )
    if self.enabled:
        return self.hook(**kwargs)
    return self.node(**kwargs)
Calls the hook or the node based on the enabled status.
Parameters:
- **kwargs: The keyword arguments for the call.
Returns:
- Any: The result of the hook or the node call.
Raises:
- CoreConfigSupportError: If the hook is not properly initialized.
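The enable/disable dispatch and the initialization guard described above are illustrated by the following self-contained sketch. `ArgumentHook`, `Doubled` and `NamedFn` are hypothetical. The sketch raises the built-in `RuntimeError` in place of excore's `CoreConfigSupportError` and drops `@final`.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable


class ArgumentHook(ABC):
    """Hypothetical mirror of ConfigArgumentHook's enable/disable dispatch."""

    def __init__(self, node: Callable[..., Any], enabled: bool = True) -> None:
        if not hasattr(node, "name"):
            raise ValueError("The `node` must have name attribute.")
        self.node = node
        self.enabled = enabled
        self.name = node.name
        self._is_initialized = True

    @abstractmethod
    def hook(self, **kwargs: Any) -> Any: ...

    def __call__(self, **kwargs: Any) -> Any:
        # Guards against subclasses that forget to call super().__init__(node).
        if not getattr(self, "_is_initialized", False):
            raise RuntimeError("Call super().__init__(node) in the subclass")
        # Enabled hooks run their own logic; disabled ones fall through to the node.
        return self.hook(**kwargs) if self.enabled else self.node(**kwargs)


class Doubled(ArgumentHook):
    def hook(self, **kwargs: Any) -> Any:
        return 2 * self.node(**kwargs)


class NamedFn:
    name = "add"

    def __call__(self, x: int, y: int) -> int:
        return x + y


print(Doubled(NamedFn())(x=1, y=2))                 # 6
print(Doubled(NamedFn(), enabled=False)(x=1, y=2))  # 3
```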
🅼 __excore_prepare__
@classmethod
def __excore_prepare__(
    cls, node: ConfigNode, hook_info: str, config: ConfigDict
) -> ConfigNode:
    hook_name, field = config._get_name_and_field(hook_info)
    if not isinstance(hook_name, str):
        raise CoreConfigParseError(
            f"More than one or none of hooks are found with `{hook_info}`."
        )
    hook_node = config._get_node_from_name_and_field(
        hook_name, field, ConfigHookNode
    )[0]
    node = hook_node(node=node)
    return node
Prepares the hook with configuration.
Parameters:
- node (ConfigNode): The node to wrap.
- hook_info (str): The hook information.
- config (ConfigDict): The configuration dictionary.
Returns:
- ConfigNode: The prepared node.
Raises:
- CoreConfigParseError: If more than one or no hooks are found.
🅲 GetAttr
class GetAttr(ConfigArgumentHook):
    flag: str = "."
A subclass of ConfigArgumentHook for getting attributes.
Attributes:
- flag (str): The flag associated with the hook, `"."`.
- attr (str): The attribute to get.
🅼 __init__
def __init__(self, node: ConfigNode, attr: str) -> None:
    super().__init__(node)
    self.attr = attr
Initializes the hook with a node and attribute.
Parameters:
- node (ConfigNode): The node associated with the hook.
- attr (str): The attribute to get.
🅼 hook
def hook(self, **params: NodeParams) -> Any:
    target = self.node(**params)
    if isinstance(target, ModuleNode):
        raise ModuleBuildError(f"Do not support `{DO_NOT_CALL_KEY}`")
    return eval("target." + self.attr)
Implements the hook logic to get the attribute.
Parameters:
- **params: The parameters for the hook.
Returns:
- Any: The value of the attribute.
Raises:
- ModuleBuildError: If the wrapped node returns an uncalled `ModuleNode`, because `DO_NOT_CALL_KEY` is not supported.
🅼 from_list
@classmethod
def from_list(cls, node: ConfigNode, attrs: list[str]) -> ConfigNode:
    for attr in attrs:
        node = cls(node, attr)
    return node
Creates a chain of GetAttr hooks.
Parameters:
- node (ConfigNode): The initial node.
- attrs (list[str]): The list of attributes to get.
Returns:
- ConfigNode: The final node in the chain.
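Chaining attribute hooks with `from_list` produces nested wrappers. The outermost wrapper reads the last attribute from the result of the inner ones. The standalone sketch below uses a hypothetical `GetAttrHook`. It deliberately swaps the original's `eval` for `operator.attrgetter`, which accepts dotted paths such as `"head.out_features"` but, unlike `eval`, does not support indexing such as `layers[0]`.

```python
from __future__ import annotations

import operator
from types import SimpleNamespace
from typing import Any, Callable


class GetAttrHook:
    """Hypothetical chainable attribute hook using operator.attrgetter instead of eval."""

    def __init__(self, node: Callable[..., Any], attr: str) -> None:
        self.node = node
        self.attr = attr

    def __call__(self, **params: Any) -> Any:
        target = self.node(**params)
        return operator.attrgetter(self.attr)(target)

    @classmethod
    def from_list(cls, node: Callable[..., Any], attrs: list[str]) -> Any:
        for attr in attrs:
            node = cls(node, attr)  # each hook wraps the previous one
        return node


def build_model(width: int) -> SimpleNamespace:
    return SimpleNamespace(head=SimpleNamespace(out_features=width))


get_width = GetAttrHook.from_list(build_model, ["head", "out_features"])
print(get_width(width=10))  # 10
```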
🅼 __excore_prepare__
@classmethod
def __excore_prepare__(
    cls, node: ConfigNode, hook_info: str, config: ConfigDict
) -> ConfigNode:
    return cls(node, hook_info)
Prepares the hook with configuration.
Parameters:
- node (ConfigNode): The node to wrap.
- hook_info (str): The hook information.
- config (ConfigDict): The configuration dictionary.
Returns:
- ConfigNode: The prepared node.
🅲 VariableReference
class VariableReference(ClassNode):
    _name: str = None
A subclass of ClassNode for variable references.
It inherits from `ClassNode` only for convenience.
🅼 __excore_parse__
@classmethod
def __excore_parse__(cls, config: ConfigDict, **locals) -> VariableReference:
    name = locals["name"]
    logger.ex(f"Got `name` {name}.")
    parsed_value = config._parse_env_var(name)
    if parsed_value != name:
        node = cls(parsed_value)
    elif name not in config:
        raise CoreConfigParseError(f"Can not find reference: {name}.")
    else:
        node = cls(config[name])
    node._name = name
    return node
Finds the reference and builds the node.
Parameters:
- config (ConfigDict): The configuration to parse.
- **locals: Additional local variables for parsing.
Returns:
- VariableReference: The parsed node.
Raises:
- CoreConfigParseError: If the reference cannot be found.
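The top-level lookup behind an `&name` reference can be illustrated with a plain dict. This sketch covers only the config lookup and the "not found" error. It leaves out the environment-variable branch, because `_parse_env_var` is not documented here. `resolve_reference` is a hypothetical helper, and it raises the built-in `LookupError` in place of `CoreConfigParseError`.

```python
from __future__ import annotations

from typing import Any


def resolve_reference(config: dict[str, Any], value: Any) -> Any:
    """Resolve an `&name` string against the top level of `config`."""
    if not (isinstance(value, str) and value.startswith("&")):
        return value  # not a reference; returned unchanged
    name = value[1:]
    if name not in config:
        raise LookupError(f"Can not find reference: {name}.")
    return config[name]


config = {"lr": 0.01, "Optimizer": {"lr": "&lr"}}
print(resolve_reference(config, config["Optimizer"]["lr"]))  # 0.01
```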
🅼 name
@property
def name(self) -> str:
    return self._name
🅲 ModuleWrapper
class ModuleWrapper(dict):
🅼 __init__
def __init__(
    self,
    modules: (
        dict[str, ConfigNode] | list[ConfigNode] | ConfigNode | None
    ) = None,
    is_dict: bool = False,
) -> None:
    # Set before the early return so `__call__` works on an empty wrapper.
    self.is_dict = is_dict
    if modules is None:
        return
    if isinstance(modules, (ModuleNode, ConfigArgumentHook)):
        self[modules.name] = modules
    elif isinstance(modules, dict):
        for k, m in modules.items():
            if isinstance(m, list):
                m = ModuleWrapper(m)
            self[k] = m
    elif isinstance(modules, list):
        for m in modules:
            self[self._get_name(m)] = m
        if len(self) != len(modules):
            raise ValueError("Currently not support for the same class name")
    else:
        raise TypeError(
            f"Expect modules to be `list`, `dict` or `ModuleNode`, but got {type(modules)}"
        )
🅼 _get_name
def _get_name(self, m) -> Any:
    if hasattr(m, "name"):
        return m.name
    return m.__class__.__name__
🅼 __lshift__
def __lshift__(self, params: NodeParams) -> None:
    if len(self) == 1:
        self[next(iter(self.keys()))] << params
    else:
        raise RuntimeError("Wrapped more than 1 ModuleNode, index first")
🅼 first
def first(self) -> NodeInstance | Self:
    if len(self) == 1:
        return next(iter(self.values()))
    return self
🅼 __getattr__
def __getattr__(self, __name: str) -> Any:
    if __name in self.keys():
        return self[__name]
    raise KeyError(
        f"Invalid key `{__name}`, must be one of `{list(self.keys())}`"
    )
🅼 __call__
def __call__(self):
    if self.is_dict:
        return {k: v() for k, v in self.items()}
    res = [m() for m in self.values()]
    return res[0] if len(res) == 1 else res
🅼 __repr__
def __repr__(self) -> str:
    return f"ModuleWrapper{list(self.values())}"
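The calling convention of `ModuleWrapper` has three cases. A single wrapped node is unwrapped. Several nodes come back as a list, or as a dict when `is_dict` is set. The sketch below shows these cases with a hypothetical `Wrapper` class over plain zero-argument callables. It keys list entries by `__name__` instead of the `name` attribute used by the original.

```python
from __future__ import annotations

from typing import Any, Callable


class Wrapper(dict):
    """Hypothetical mirror of ModuleWrapper: name -> zero-argument callable."""

    def __init__(
        self,
        modules: list[Callable[[], Any]] | dict[str, Callable[[], Any]],
        is_dict: bool = False,
    ) -> None:
        super().__init__()
        self.is_dict = is_dict
        if isinstance(modules, dict):
            self.update(modules)
        else:
            for m in modules:
                self[getattr(m, "__name__", type(m).__name__)] = m
            if len(self) != len(modules):
                raise ValueError("Currently not support for the same class name")

    def __call__(self) -> Any:
        if self.is_dict:
            return {k: v() for k, v in self.items()}
        results = [m() for m in self.values()]
        # A single wrapped node is unwrapped; several come back as a list.
        return results[0] if len(results) == 1 else results


def train() -> str:
    return "train-set"


def val() -> str:
    return "val-set"


print(Wrapper([train])())                      # train-set
print(Wrapper([train, val])())                 # ['train-set', 'val-set']
print(Wrapper({"a": train}, is_dict=True)())   # {'a': 'train-set'}
```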