diff --git a/aiter/jit/core.py b/aiter/jit/core.py index 098af6ce6a..68c09aeb73 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -190,7 +190,7 @@ def update_config_files(self, file_path: str, merge_name: str): for i, path in enumerate(path_list): if not os.path.exists(path): - logger.info(f"path {i+1}: {path} (not exist)") + logger.info(f"path {i + 1}: {path} (not exist)") continue df = pd.read_csv(path) @@ -916,7 +916,7 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources): def FinalFunc(): logger.info( - f"\033[32mfinish build [{md_name}], cost {time.perf_counter()-startTS:.1f}s \033[0m" + f"\033[32mfinish build [{md_name}], cost {time.perf_counter() - startTS:.1f}s \033[0m" ) mp_lock(lockPath=lock_path, MainFunc=MainFunc, FinalFunc=FinalFunc) @@ -1130,6 +1130,8 @@ def _ctypes_call(func, fc_name, md_name): _cache = {} _arg_checked = False + _sig = inspect.signature(func) + _hints = typing.get_type_hints(func) def _ensure_loaded(): if _cache: @@ -1168,8 +1170,7 @@ def _opt_sym(name, argtypes=(), restype=None): err_getter = _opt_sym("aiter_get_last_error", restype=ctypes.c_char_p) err_clear = _opt_sym("aiter_clear_last_error") - hints = typing.get_type_hints(func) - ret_hint = hints.get("return") + ret_hint = _hints.get("return") ctypes_data_return = ctypes_status_mode and ret_hint is int if ctypes_status_mode: @@ -1183,8 +1184,8 @@ def _opt_sym(name, argtypes=(), restype=None): argtypes = [] has_tensor = False - for pname in inspect.signature(func).parameters: - hint = hints.get(pname) + for pname in _sig.parameters: + hint = _hints.get(pname) origin = typing.get_origin(hint) type_args = typing.get_args(hint) if hint is torch.Tensor: @@ -1252,20 +1253,17 @@ def _check_args_before_convert(bound_args, hints): elif hint is str: if not isinstance(value, str): raise TypeError( - f"{fc_name}: '{pname}' expects str, " - f"got {type(value).__name__}" + f"{fc_name}: '{pname}' expects str, got {type(value).__name__}" ) elif hint is bool: if not isinstance(value, (bool, int)): raise TypeError( - f"{fc_name}: '{pname}' expects bool, " - f"got {type(value).__name__}" + f"{fc_name}: '{pname}' expects bool, got {type(value).__name__}" ) elif hint is int: if not isinstance(value, int): raise TypeError( - f"{fc_name}: '{pname}' expects int, " - f"got {type(value).__name__}" + f"{fc_name}: '{pname}' expects int, got {type(value).__name__}" ) elif hint is float: if not isinstance(value, (float, int)): @@ -1287,13 +1285,11 @@ def caller(*args, **kwargs): from ..test_common import log_args log_args(func, *args, **kwargs) - sig = inspect.signature(func) - bound = sig.bind(*args, **kwargs) + bound = _sig.bind(*args, **kwargs) bound.apply_defaults() - hints = typing.get_type_hints(func) if not _arg_checked: - _check_args_before_convert(bound.arguments, hints) + _check_args_before_convert(bound.arguments, _hints) _arg_checked = True c_args = [] @@ -1301,7 +1297,7 @@ def caller(*args, **kwargs): tensor_device = None for pname, value in bound.arguments.items(): - hint = hints.get(pname) + hint = _hints.get(pname) origin = typing.get_origin(hint) type_args = typing.get_args(hint)