Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,8 @@ def _ctypes_call(func, fc_name, md_name):

_cache = {}
_arg_checked = False
_sig = inspect.signature(func)
_hints = typing.get_type_hints(func)

def _ensure_loaded():
if _cache:
Expand Down Expand Up @@ -1168,8 +1170,7 @@ def _opt_sym(name, argtypes=(), restype=None):
err_getter = _opt_sym("aiter_get_last_error", restype=ctypes.c_char_p)
err_clear = _opt_sym("aiter_clear_last_error")

hints = typing.get_type_hints(func)
ret_hint = hints.get("return")
ret_hint = _hints.get("return")
ctypes_data_return = ctypes_status_mode and ret_hint is int

if ctypes_status_mode:
Expand All @@ -1183,8 +1184,8 @@ def _opt_sym(name, argtypes=(), restype=None):

argtypes = []
has_tensor = False
for pname in inspect.signature(func).parameters:
hint = hints.get(pname)
for pname in _sig.parameters:
hint = _hints.get(pname)
origin = typing.get_origin(hint)
type_args = typing.get_args(hint)
if hint is torch.Tensor:
Expand Down Expand Up @@ -1287,21 +1288,19 @@ def caller(*args, **kwargs):
from ..test_common import log_args

log_args(func, *args, **kwargs)
sig = inspect.signature(func)
bound = sig.bind(*args, **kwargs)
bound = _sig.bind(*args, **kwargs)
bound.apply_defaults()
hints = typing.get_type_hints(func)

if not _arg_checked:
_check_args_before_convert(bound.arguments, hints)
_check_args_before_convert(bound.arguments, _hints)
_arg_checked = True

c_args = []
aiter_refs = []
tensor_device = None

for pname, value in bound.arguments.items():
hint = hints.get(pname)
hint = _hints.get(pname)
origin = typing.get_origin(hint)
type_args = typing.get_args(hint)

Expand Down
Loading