# Source code for autoasm.core

# -*- coding: utf-8 -*-
"""
    autoasm
    ~~~~~~~
    A simple dependency injection framework

    :copyright: © 2018 by Neo Ko
    :license: Apache 2.0, see LICENSE for more details.
"""

import asyncio
import enum
import functools
import threading

from . import config
from . import execptions


class ContextType(enum.Enum):
    """Execution mode a :class:`Context` is created for.

    ``Context.__init__`` only looks up / stores an event loop when the
    context type is ``ASYNC``.
    """

    ASYNC = 1
    SYNC = 2
class ServiceType(enum.Enum):
    """Kind of factory handed to ``Context._register``.

    ``SYNC`` factories are stored in the synchronous registry and called
    directly; ``ASYNC`` factories are stored in the asynchronous registry
    and awaited when resolved.
    """

    ASYNC = 1
    SYNC = 2
# Alias of ``config.CoerceType``; used below as the default of
# ``Context.configure_from_module``.
CoerceType = config.CoerceType


class Empty:
    """Marker type whose instances mean "no value stored under this key".

    Used instead of ``None`` so that ``None`` can itself be a stored value.
    """


# Module-wide sentinel passed as the ``dict.get`` default by ``_is_empty``.
_EMPTY = Empty()
class Context:
    """
    autoasm core class, create a context to get dependency

    A context keeps three kinds of state:

    * configuration values (``self._config``),
    * service factories, synchronous and asynchronous
      (``self._dependencies`` / ``self._async_dependencies``),
    * the objects those factories produced (``self._entities``); every
      factory result is cached, so a factory runs at most once per key.
    """

    def __init__(self, name, ctx_type=ContextType.SYNC, loop=None):
        """
        :param str name: name returned by ``str(context)``
        :param ContextType ctx_type: execution mode of the context
        :param loop: event loop to keep for an ``ASYNC`` context; anything
            that is not an ``asyncio.AbstractEventLoop`` is replaced by
            ``asyncio.get_event_loop()``. Ignored for ``SYNC`` contexts.
        """
        self._name = name
        self._config = {}
        self._lock = threading.RLock()
        self._alock = asyncio.Lock()
        self._dependencies = {}
        self._async_dependencies = {}
        self._entities = {}
        self._type = ctx_type
        # Always define the attribute so SYNC contexts do not raise
        # AttributeError when ``_loop`` is inspected.
        self._loop = None
        if self._type == ContextType.ASYNC:
            if isinstance(loop, asyncio.AbstractEventLoop):
                self._loop = loop
            else:
                self._loop = asyncio.get_event_loop()

    def __str__(self):
        """Return the context name."""
        return self._name

    @property
    def config(self):
        """The mutable dict of configuration values held by this context."""
        return self._config

    def configure_from_module(self, name, coerce_type=CoerceType.SAME):
        """
        Build a ``config.ModuleConfig`` and merge it into the configuration.

        :param name: forwarded to ``config.ModuleConfig``
        :param coerce_type: forwarded to ``config.ModuleConfig``
        """
        cfg = config.ModuleConfig(name, coerce_type)
        self.configure(cfg)

    def configure_from_json(self, path):
        """
        Build a ``config.JsonConfig`` and merge it into the configuration.

        :param path: forwarded to ``config.JsonConfig``
        """
        cfg = config.JsonConfig(path)
        self.configure(cfg)

    def configure(self, cfg):
        """
        Merge ``cfg.to_dict()`` into this context's configuration.

        Existing keys are overwritten (plain ``dict.update`` semantics).

        :param config.Config cfg:
        :return:
        """
        self._config.update(cfg.to_dict())

    def service(self, key):
        """
        Decorator registering the decorated callable as the synchronous
        factory for ``key``. The callable itself is returned unchanged.

        :param key: name the service is resolved by
        :raises execptions.ServiceDuplicated: when ``key`` already has a
            synchronous factory (raised when the decorator is applied)
        """
        def wrap(runnable):
            self._register(runnable, key, service_type=ServiceType.SYNC)
            return runnable
        return wrap

    def async_service(self, key):
        """
        Decorator registering the decorated callable as the asynchronous
        factory for ``key``. The callable itself is returned unchanged.

        :param key: name the service is resolved by
        :raises execptions.ServiceDuplicated: when ``key`` already has an
            asynchronous factory (raised when the decorator is applied)
        """
        def wrap(runnable):
            self._register(runnable, key, service_type=ServiceType.ASYNC)
            return runnable
        return wrap

    def inject(self, *keys):
        """
        Decorator factory passing resolved dependencies as keyword arguments.

        On every call of the decorated function, each key in ``keys`` that
        the caller did not pass as a keyword argument is resolved and added
        to the keyword arguments. Keys supplied by the caller are left
        untouched and are not resolved at all.

        Coroutine functions get an ``async`` wrapper that resolves through
        the asynchronous path; other callables get a plain wrapper.

        :param keys: names of the dependencies to inject
        """
        def wrapper(func):
            @functools.wraps(func)
            def wrap(*args, **kwargs):
                # Resolve only what the caller did not provide, so an
                # explicit override never instantiates the real service.
                missing = [key for key in keys if key not in kwargs]
                kwargs.update({key: self._resolve(key) for key in missing})
                return func(*args, **kwargs)

            @functools.wraps(func)
            async def async_wrap(*args, **kwargs):
                missing = [key for key in keys if key not in kwargs]
                for key in missing:
                    kwargs[key] = await self._async_resolve(key)
                return await func(*args, **kwargs)

            if asyncio.iscoroutinefunction(func):
                return async_wrap
            return wrap
        return wrapper

    @staticmethod
    def _not_found(key):
        """Build the ``ServiceNotFound`` error for an unregistered ``key``."""
        _msg = "can't find dependency with key {key}"
        return execptions.ServiceNotFound(_msg.format(key=key))

    def _resolve(self, key):
        """
        Return the value for ``key`` while holding the thread lock.

        Lookup order: configuration value, cached entity, then a fresh
        entity built from the synchronous factory.

        :raises execptions.ServiceNotFound: when nothing is known for ``key``
        """
        with self._lock:
            if not _is_empty(self._config, key):
                return self._config.get(key)
            if _is_empty(self._entities, key):
                return self._resolve_unsafe(key)
            return self._entities.get(key)

    def _resolve_unsafe(self, key):
        """
        Call the synchronous factory for ``key`` and cache its result.

        Does no locking of its own; ``_resolve`` calls it with the lock held.

        :raises execptions.ServiceNotFound: when no factory is registered
        """
        dependency = self._dependencies.get(key)
        if dependency is None:
            raise self._not_found(key)
        entity = dependency()
        self._entities[key] = entity
        return entity

    async def _async_resolve(self, key):
        """
        Asynchronous counterpart of ``_resolve``, guarded by the asyncio lock.

        :raises execptions.ServiceNotFound: when nothing is known for ``key``
        """
        async with self._alock:
            if not _is_empty(self._config, key):
                return self._config.get(key)
            if _is_empty(self._entities, key):
                return await self._async_resolve_unsafe(key)
            return self._entities.get(key)

    async def _async_resolve_unsafe(self, key):
        """
        Build and cache the entity for ``key`` from an asynchronous factory,
        falling back to the synchronous factory when no asynchronous one
        is registered.

        Does no locking of its own; ``_async_resolve`` calls it with the
        asyncio lock held.

        :raises execptions.ServiceNotFound: when neither registry has ``key``
        """
        async_dep = self._async_dependencies.get(key)
        if async_dep is not None:
            entity = await async_dep()
        else:
            dep = self._dependencies.get(key)
            if dep is None:
                raise self._not_found(key)
            entity = dep()
        self._entities[key] = entity
        return entity

    def _register(self, runnable, key, service_type=ServiceType.SYNC):
        """
        Store ``runnable`` as the factory for ``key``.

        :param runnable: factory; it is later called with no arguments
            (and awaited when registered as ``ServiceType.ASYNC``)
        :param key: name the service is resolved by
        :param ServiceType service_type: which registry to store it in
        :raises TypeError: when ``service_type`` is not a known member
        :raises execptions.ServiceDuplicated: when that registry already
            holds a factory for ``key``
        :return:
        """
        with self._lock:
            if service_type == ServiceType.SYNC:
                registry = self._dependencies
            elif service_type == ServiceType.ASYNC:
                registry = self._async_dependencies
            else:
                raise TypeError('unknown service_type')

            if _is_empty(registry, key):
                registry[key] = runnable
                return

            # Report the factory that is actually occupying the slot in the
            # registry this registration targeted.
            entity_type = type(registry.get(key))
            _msg = '{key} is registered and type is {entity_type}, ' \
                   'but get new registration with {t}'
            msg = _msg.format(key=key, entity_type=entity_type,
                              t=type(runnable))
            raise execptions.ServiceDuplicated(msg)

    def workspace(self, ws):
        """
        Adopt every factory registered on ``ws`` and bind ``ws`` to this
        context.

        Factories from ``ws`` overwrite same-key factories already present
        here (plain ``dict.update`` semantics, no duplicate check).

        :param Workspace ws:
        :return:
        """
        self._dependencies.update(ws._dependencies)
        self._async_dependencies.update(ws._async_dependencies)
        ws.bind(self)
class Workspace(Context):
    """
    A context whose injections are delegated to another context.

    Factories are registered on the workspace itself; ``inject`` only works
    once the workspace has been bound to a context with ``bind`` (which
    ``Context.workspace`` does).
    """

    def __init__(self, name):
        """
        :param str name: name returned by ``str(workspace)``
        """
        super().__init__(name)
        # Context that resolves this workspace's injections; None until
        # ``bind`` is called.
        self._context = None  # type: Context

    def bind(self, ctx):
        """
        Remember ``ctx`` as the context used to resolve injections.

        :param Context ctx:
        :return:
        """
        self._context = ctx

    def inject(self, *keys):
        """
        Decorator factory that defers to the bound context's ``inject``.

        The bound context is looked up on every call of the decorated
        function, so a function may be decorated before the workspace is
        bound.

        :param keys: names of the dependencies to inject
        :raises execptions.WorkspaceNotBinding: when the decorated function
            is called while no context is bound
        """
        def _inject(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if self._context is None:
                    raise execptions.WorkspaceNotBinding()
                injected = self._context.inject(*keys)(func)
                return injected(*args, **kwargs)
            return wrapper
        return _inject
def _is_empty(container, key):
    """
    Tell whether ``container`` has no usable value under ``key``.

    A key counts as empty when it is absent, or when the stored value is
    itself an ``Empty`` instance.

    :param typing.Mapping container:
    :param str key:
    :return: ``True`` when empty, ``False`` otherwise
    """
    value = container.get(key, _EMPTY)
    return isinstance(value, Empty)