registry
TOC
- Attributes:
- 🅰 _name_re - TODO: Maybe some methods need to be cleared.
- 🅰 _private_flag - TODO: Maybe some methods need to be cleared.
- 🅰 _ClassType - TODO: Maybe some methods need to be cleared.
- Functions:
- Classes:
- 🅲 RegistryMeta
- 🅲 Registry - A registry that stores functions and classes by name.
Attributes
🅰 _name_re
# Pattern a registry name has to satisfy: ASCII letters, digits and underscores only.
_name_re = re.compile("^[A-Za-z0-9_]+$") #TODO: Maybe some methods need to be cleared.
🅰 _private_flag
# Prefix that marks a name as private; registries whose name starts with it
# are not stored in the registry pool.
_private_flag: str = "__" #TODO: Maybe some methods need to be cleared.
🅰 _ClassType
# Alias for "any class object", used in type annotations.
_ClassType = Type[Any] #TODO: Maybe some methods need to be cleared.
Functions
🅵 _is_pure_ascii
def _is_pure_ascii(name: str) -> None:
    """Validate that ``name`` only contains ASCII letters, digits and underscores.

    Args:
        name: Candidate name to validate.

    Raises:
        ValueError: If ``name`` is empty or contains any other character.
    """
    # ``fullmatch`` is used instead of ``match`` because ``$`` also matches just
    # before a trailing newline, which would let a name such as "Model\n" pass.
    if not _name_re.fullmatch(name):
        # The pattern does not accept dashes, so the message must not promise them.
        raise ValueError(
            "Unexpected name, only support ASCII letters, ASCII digits and "
            f"underscores, but got {name}."
        )
🅵 _is_function_or_class
def _is_function_or_class(module: Any) -> bool:
return inspect.isfunction(module) or inspect.isclass(module)
🅵 _default_filter_func
def _default_filter_func(values: Sequence[Any]) -> bool:
return all(v for v in values)
🅵 _default_match_func
_default_match_func
def _default_match_func(m: str, base_module: ModuleType) -> bool:
if not m.startswith("__"):
m = getattr(base_module, m)
if inspect.isfunction(m) or inspect.isclass(m):
return True
return False
🅵 _get_module_name
def _get_module_name(m: ModuleType | _ClassType | FunctionType) -> str:
return getattr(m, "__qualname__", m.__name__)
🅵 load_registries
load_registries
def load_registries() -> None:
    """Load the cached registries and lock registration afterwards.

    Logs a warning and returns early when the registry cache file does not
    exist. Exits the process with status 1 when the registry pool is still
    empty after loading.
    """
    message = "Please run `excore auto-register` in your command line first!"
    cache_exists = os.path.exists(workspace.registry_cache_file)
    if not cache_exists:
        logger.warning(message)
        return

    Registry.load()
    Registry.lock_register()
    if Registry._registry_pool:
        return
    logger.critical(f"No module has been registered. {message}")
    sys.exit(1)
Classes
🅲 RegistryMeta
class RegistryMeta(type):
# Maps a registry name to the instance created for it; `__call__` consults it
# so that constructing a registry with an already-known name returns the
# stored instance instead of building a new one.
_registry_pool: dict[str, Registry] = {}
🅼 __call__
__call__
def __call__(cls, name: str, **kwargs: Any) -> Registry:
    """Create the registry called ``name`` or hand back the pooled one.

    Args:
        name: Registry name; validated with ``_is_pure_ascii``.
        **kwargs: Forwarded to the class constructor when a new instance is
            built. ``extra_field`` is additionally compared against the pooled
            instance to warn about conflicting arguments.

    Returns:
        The pooled instance when ``name`` is already known, otherwise a freshly
        built one (pooled unless ``name`` starts with the private prefix).
    """
    _is_pure_ascii(name)

    if name not in cls._registry_pool:
        created = super().__call__(name, **kwargs)
        if not name.startswith(_private_flag):
            cls._registry_pool[name] = created
        return created

    pooled = cls._registry_pool[name]
    requested = kwargs.get("extra_field")
    if isinstance(requested, str):
        requested = [requested]
    conflicting = (
        requested
        and hasattr(pooled, "extra_field")
        and requested != pooled.extra_field
    )
    if conflicting:
        logger.warning(
            f"{cls.__name__}: `{name}` has already existed, different arguments will be ignored"
        )
    return pooled
Ensures that `__init__` is only called once per registry name.
🅲 Registry
class Registry(dict):
# Cached global registry built by `make_global`; None until it is first built.
_globals: Registry | None = None
# Class-wide switch toggled by `lock_register` / `unlock_register`.
_prevent_register: bool = False
# Class-level placeholder; `__init__` gives every instance its own dict.
# NOTE(review): `register_module` stores lists as values, not `str` -- the annotation looks off.
extra_info: dict[str, str] = None
# NOTE(review): requires `__repr__` to be defined before this line in the real class body -- confirm ordering.
__str__ = __repr__
A registry that stores functions and classes by name.
Attributes:
- name (str): The name of the registry.
- extra_field (str | Sequence[str] | None): A field or fields that can be used to store additional information about each function or class in the registry.
- extra_info (dict[str, list[Any]]): A dictionary that maps each registered name to a list of extra values associated with that name (if any).
- _globals (Registry | None): A static variable that stores a global registry containing all functions and classes registered using Registry.
Examples:
>>> from excore import Registry
>>> MODEL = Registry('Model', extra_field=['is_backbone'])
>>> @MODEL.register(force=False, is_backbone=True)
... class ResNet:
... ...
🅼 __init__
__init__
def __init__(
self, /, name: str, *, extra_field: str | Sequence[str] | None = None
) -> None:
self.name = name
if extra_field:
self.extra_field = (
[extra_field] if isinstance(extra_field, str) else extra_field
)
self.extra_info = {}
🅼 dump
dump
@classmethod
def dump(cls, update: bool = False) -> None:
    """Pickle the registry pool to the workspace registry cache file.

    Args:
        update: When ``True`` and a cache file already exists, load it first
            and overlay the in-memory pool on top of it instead of replacing
            the file contents outright.
    """
    import pickle

    file_path = workspace.registry_cache_file
    # Hold the lock across the whole read-modify-write cycle. Reading the old
    # cache outside of it would let a concurrent dump slip in between the read
    # and the write and have its entries silently overwritten.
    with FileLock(file_path + ".lock", timeout=5):
        if update and os.path.exists(file_path):
            with open(file_path, "rb") as f:
                cache_to_dump = pickle.load(f)
            cache_to_dump.update(cls._registry_pool)
        else:
            cache_to_dump = cls._registry_pool
        with open(file_path, "wb") as f:
            pickle.dump(cache_to_dump, f)
    logger.success(f"Dump registry cache to {workspace.registry_cache_file}!")
🅼 load
load
@classmethod
def load(cls) -> None:
    """Merge the pickled registry cache into the registry pool.

    Exits the process with status 1 when either the workspace config file or
    the registry cache file is missing.
    """
    import pickle

    if not os.path.exists(_workspace_config_file):
        logger.warning("Please run `excore init` in your command line first!")
        sys.exit(1)

    cache_path = workspace.registry_cache_file
    if not os.path.exists(cache_path):
        logger.critical(
            "Registry cache file do not exist! Please run `excore auto-register in your command line first`"
        )
        sys.exit(1)

    # The cache is read under the same ".lock" file that `dump` writes under.
    lock = FileLock(cache_path + ".lock")
    with lock, open(cache_path, "rb") as cache_file:
        cached_pool = pickle.load(cache_file)
    cls._registry_pool.update(cached_pool)
🅼 lock_register
@classmethod
def lock_register(cls) -> None:
cls._prevent_register = True
🅼 unlock_register
@classmethod
def unlock_register(cls) -> None:
cls._prevent_register = False
🅼 get_registry
@classmethod
def get_registry(cls, name: str, default: Any = None) -> Registry:
    """Look up a registry by name in the registry pool.

    Args:
        name: Name of the registry to fetch.
        default: Value returned when no registry is stored under ``name``.
    """
    pool = Registry._registry_pool
    return pool.get(name, default)
Returns the `Registry` instance with the given name, or `default` if no such
registry exists.
🅼 find
find
@classmethod
@functools.lru_cache(32)
def find(cls, name: str) -> tuple[Any, str] | tuple[None, None]:
    """Search every pooled registry for an entry called ``name``.

    Returns:
        ``(entry, registry_name)`` for the first registry containing ``name``,
        or ``(None, None)`` when no registry has it.
    """
    # NOTE(review): results (misses included) are memoised by ``lru_cache``, so
    # a lookup may go stale if registries change afterwards -- confirm intended.
    hits = (
        (registry[name], pool_key)
        for pool_key, registry in Registry._registry_pool.items()
        if name in registry
    )
    return next(hits, (None, None))
Searches all registries for an element with the given name. If found,
returns a tuple containing the element and the name of the registry where it was found; otherwise, returns `(None, None)`.
🅼 make_global
make_global
@classmethod
def make_global(cls) -> Registry:
    """Return the global registry, building it on first use.

    The global registry is named ``"__global"`` and is filled by merging every
    registry of the pool into it with ``force=False``.
    """
    if cls._globals is None:
        merged = cls("__global")
        for registry in Registry._registry_pool.values():
            merged.merge(registry, force=False)
        cls._globals = merged
    return cls._globals
Creates a global `Registry` instance that contains all elements from all
other registries. If the global registry already exists, returns it instead of creating a new one.
🅼 __setitem__
def __setitem__(self, k: str, v: Any) -> None:
super().__setitem__(k, v)
🅼 __repr__
def __repr__(self) -> str:
    """Render the registry as a two-column table of entry names and targets."""
    rows = list(self.items())
    return _create_table(["NAME", "DIR"], rows)
🅼 register_module
register_module
@overload
def register_module(
self,
module: Callable[..., Any],
force: bool = ...,
_is_str: bool = ...,
**extra_info: Any
) -> Callable[..., Any]:
🅼 register_module
register_module
@overload
def register_module(
self,
module: ModuleType,
force: bool = ...,
_is_str: bool = ...,
**extra_info: Any
) -> ModuleType:
🅼 register_module
register_module
@overload
def register_module(
self,
module: str,
force: bool = ...,
_is_str: Literal[True] = ...,
**extra_info: Any
) -> str:
🅼 register_module
register_module
def register_module(self, module, force=False, _is_str=False, **extra_info):
    """Register ``module`` in this registry under its name.

    Args:
        module: A function, class or module object, or -- when ``_is_str`` is
            true -- a dotted path string whose last component is the name.
        force: Overwrite an existing entry of the same name instead of raising.
        _is_str: Treat ``module`` as a dotted path string.
        **extra_info: Values for the fields declared through ``extra_field``.

    Returns:
        ``module`` itself, unchanged.

    Raises:
        TypeError: If ``module`` is not a function, class or module object
            (only checked when ``_is_str`` is false).
        ValueError: If the name is already registered and ``force`` is false,
            if ``extra_info`` is given but the registry has no ``extra_field``,
            or if ``extra_info`` holds a key that is not in ``extra_field``.
    """
    if Registry._prevent_register:
        logger.ex("Registry has been locked!!!")
        return module

    if not _is_str:
        if not (
            _is_function_or_class(module) or isinstance(module, ModuleType)
        ):
            raise TypeError(
                f"Only support function or class, but got {type(module)}"
            )
        name = _get_module_name(module)
    else:
        name = module.split(".")[-1]

    if not force and name in self:
        raise ValueError(f"The name {name} exists")

    if extra_info:
        if not hasattr(self, "extra_field"):
            raise ValueError(
                f"Registry `{self.name}` does not have `extra_field`."
            )
        for k in extra_info:
            if k not in self.extra_field:
                raise ValueError(
                    f"Registry `{self.name}`: 'extra_info' does not has expected key {k}."
                )
        self.extra_info[name] = [extra_info.get(k) for k in self.extra_field]
    elif hasattr(self, "extra_field"):
        # Keep one slot per declared field even when no value was supplied.
        self.extra_info[name] = [None] * len(self.extra_field)

    if _is_str:
        target = module
    elif isinstance(module, ModuleType):
        target = name
    else:
        target = ".".join([module.__module__, module.__qualname__])

    logger.ex(f"Register {name} with {target}.")
    self[name] = target

    # Mirror the entry into the global registry. The global registry itself is
    # skipped: otherwise it would call back into itself, raising "name exists"
    # when ``force`` is false and recursing without bound when it is true.
    global_registry = Registry._globals
    if (
        global_registry is not None
        and global_registry is not self
        and not name.startswith(_private_flag)
    ):
        # Forward only the fields the global registry declares; passing fields
        # it does not know makes the mirrored call fail with a ValueError.
        known_fields = getattr(global_registry, "extra_field", ())
        forwarded = {k: v for k, v in extra_info.items() if k in known_fields}
        global_registry.register_module(target, force, True, **forwarded)
    return module
🅼 register
def register(
self, force: bool = False, **extra_info: Any
) -> Callable[..., Any]:
return functools.partial(self.register_module, force=force, **extra_info)
Decorator that registers a function or class with the current `Registry`.
Any keyword arguments provided are added to the `extra_info` list for the registered element. If `force` is True, overwrites any existing element with the same name.
🅼 register_all
register_all
def register_all(
    self,
    modules: Sequence[Callable[..., Any]],
    extra_info: Sequence[dict[str, Any]] | None = None,
    force: bool = False,
    _is_str: bool = False,
) -> None:
    """Register several modules in one go through ``register_module``.

    Args:
        modules: Objects (or dotted path strings when ``_is_str`` is true).
        extra_info: One dict of extra values per module; when omitted or
            empty, every module is registered without extra values.
        force: Passed on to ``register_module``.
        _is_str: Passed on to ``register_module``.
    """
    if Registry._prevent_register:
        return
    per_module_info = extra_info or [{}] * len(modules)
    # ``zip`` stops at the shorter of the two sequences.
    for target, fields in zip(modules, per_module_info):
        self.register_module(target, force=force, _is_str=_is_str, **fields)
Registers multiple functions or classes with the current `Registry`.
If `force` is True, overwrites any existing elements with the same names.
🅼 get_extra_info
get_extra_info
def get_extra_info(self, key: str, name: str) -> Any:
if name not in self.extra_field:
raise ValueError(
f"Expected name to be one of `{self.extra_field}`, but got `{name}`."
)
for target_name, info in zip(self.extra_field, self.extra_info[key]):
if name == target_name:
return info
🅼 merge
merge
def merge(
self, others: Registry | Sequence[Registry], force: bool = False
) -> None:
if not isinstance(others, (list, tuple, Sequence)):
others = [others]
for other in others:
if not isinstance(other, Registry):
raise TypeError(f"Expect `Registry` type, but got {type(other)}")
modules = list(other.values())
self.register_all(modules, force=force, _is_str=True)
Merge the contents of one or more other registries into the current one.
If `force` is True, overwrites any existing elements with the same names.
🅼 filter
filter
def filter(
    self,
    filter_field: Sequence[str] | str,
    filter_func: Callable[[Sequence[Any]], bool] = _default_filter_func,
) -> list[str]:
    """List the entry names whose selected extra values pass ``filter_func``.

    Args:
        filter_field: Field name, or sequence of field names, to inspect.
        filter_func: Predicate receiving the selected values of one entry,
            ordered as in ``self.extra_field``.

    Returns:
        The matching entry names in sorted order.
    """
    wanted = [filter_field] if isinstance(filter_field, str) else filter_field
    positions = [
        pos for pos, field in enumerate(self.extra_field) if field in wanted
    ]
    selected = [
        name
        for name in self.keys()
        if filter_func([self.extra_info[name][pos] for pos in positions])
    ]
    selected.sort()
    return selected
Returns a sorted list of all names in the registry for which the values of
the given extra field(s) pass a filtering function.
🅼 match
match
def match(
    self,
    base_module: ModuleType,
    match_func: Callable[[str, ModuleType], bool] = _default_match_func,
    force: bool = False,
) -> None:
    """Register the functions and classes of ``base_module`` accepted by ``match_func``.

    Args:
        base_module: Module whose ``__dict__`` names are inspected.
        match_func: Predicate called with each attribute name and the module.
        force: Passed on to ``register_all``.
    """
    if Registry._prevent_register:
        return
    matched_modules = []
    for attr_name in base_module.__dict__:
        if not match_func(attr_name, base_module):
            continue
        candidate = getattr(base_module, attr_name)
        # Only functions and classes are kept, whatever ``match_func`` accepted.
        if _is_function_or_class(candidate):
            matched_modules.append(candidate)
    logger.ex("matched modules:{}", [found.__name__ for found in matched_modules])
    self.register_all(matched_modules, force=force)
Registers all functions or classes from the given module that pass a matching
function. If `match_func` is not provided, uses `_default_match_func`.
🅼 module_table
module_table
def module_table(
    self,
    filter: Sequence[str] | str | None = None,
    select_info: Sequence[str] | str | None = None,
    module_list: Sequence[str] | None = None,
    **table_kwargs: Any,
) -> str:
    """Render the registered entries, and optionally extra fields, as a table.

    Args:
        filter: ``fnmatch`` pattern, or sequence of patterns; an entry is
            listed when it matches at least one of them. All entries are
            listed when omitted or empty.
        select_info: Extra field name, or sequence of names, to add as
            columns, in the order given.
        module_list: Entry names to consider instead of every registered key.
        **table_kwargs: Passed on to ``_create_table``.

    Returns:
        The rendered table, prefixed with a newline.

    Raises:
        ValueError: If a requested info key is not in ``self.extra_field``.
    """
    if select_info is None:
        select_info = []
    else:
        select_info = (
            [select_info] if isinstance(select_info, str) else select_info
        )
        for info_key in select_info:
            if info_key not in self.extra_field:
                raise ValueError(f"Got unexpected info key {info_key}")

    all_modules = module_list if module_list else list(self.keys())
    if filter:
        patterns = [filter] if isinstance(filter, str) else filter
        # Accumulate the union over all patterns. Patterns without any match
        # simply contribute nothing, which yields an empty table rather than
        # leaving ``modules`` undefined.
        matched: set[str] = set()
        for pattern in patterns:
            matched.update(fnmatch.filter(all_modules, pattern))
        modules = sorted(matched)
    else:
        modules = sorted(all_modules)

    table_headers = [f"{item}" for item in [self.name, *select_info]]
    # Resolve the indices in ``select_info`` order so that every value lands
    # under its own header, whatever order the fields were requested in.
    select_idx = [
        list(self.extra_field).index(info_key) for info_key in select_info
    ]
    table = _create_table(
        table_headers,
        [
            (name, *[self.extra_info[name][idx] for idx in select_idx])
            for name in modules
        ],
        **table_kwargs,
    )
    return "\n" + table
Returns a table containing information about each registered function or
class, filtered by name and/or extra info fields. `select_info` specifies which extra info fields to include in the table, while `module_list` specifies which modules to include (by default, includes all modules).
🅼 registry_table
registry_table
@classmethod
def registry_table(cls, **table_kwargs) -> str:
    """Render the names of all pooled registries as a one-column table.

    Args:
        **table_kwargs: Passed on to ``_create_table``.

    Returns:
        The rendered table, prefixed with a newline.
    """
    rows = sorted([registry_name] for registry_name in cls._registry_pool)
    rendered = _create_table(["REGISTRY"], rows, **table_kwargs)
    return "\n" + rendered
Returns a table containing the names of all available registries.