diff --git a/newsfragments/3668.performance.rst b/newsfragments/3668.performance.rst new file mode 100644 index 0000000000..f002a0505a --- /dev/null +++ b/newsfragments/3668.performance.rst @@ -0,0 +1 @@ +Optimize web3._utils.decorators.reject_recursive_repeats diff --git a/newsfragments/3671.performance.rst b/newsfragments/3671.performance.rst new file mode 100644 index 0000000000..13049a8c1e --- /dev/null +++ b/newsfragments/3671.performance.rst @@ -0,0 +1 @@ +optimize web3._utils.rpc_abi.apply_abi_formatters_to_dict diff --git a/tests/core/method-class/test_method.py b/tests/core/method-class/test_method.py index 0d82a6b602..88766d9d09 100644 --- a/tests/core/method-class/test_method.py +++ b/tests/core/method-class/test_method.py @@ -17,7 +17,6 @@ ) from web3.method import ( Method, - _apply_request_formatters, default_root_munger, ) from web3.module import ( @@ -61,11 +60,7 @@ def test_get_formatters_default_formatter_for_falsy_config(): default_result_formatters = method.result_formatters( method.method_selector_fn(), "some module" ) - assert _apply_request_formatters(["a", "b", "c"], default_request_formatters) == ( - "a", - "b", - "c", - ) + assert default_request_formatters(["a", "b", "c"]) == ("a", "b", "c") assert apply_result_formatters(default_result_formatters, ["a", "b", "c"]) == [ "a", "b", diff --git a/web3/_utils/decorators.py b/web3/_utils/decorators.py index 346ac5b045..7115ff7a7b 100644 --- a/web3/_utils/decorators.py +++ b/web3/_utils/decorators.py @@ -3,6 +3,8 @@ from typing import ( Any, Callable, + Set, + Tuple, TypeVar, cast, ) @@ -20,21 +22,22 @@ def reject_recursive_repeats(to_wrap: Callable[..., Any]) -> Callable[..., Any]: Prevent simple cycles by returning None when called recursively with same instance """ # types ignored b/c dynamically set attribute - to_wrap.__already_called = {} # type: ignore + already_called: Set[Tuple[int, ...]] = set() + to_wrap.__already_called = already_called # type: ignore + + add_call = already_called.add + remove_call = already_called.remove @functools.wraps(to_wrap) def wrapped(*args: Any) -> Any: - arg_instances = tuple(map(id, args)) - thread_id = threading.get_ident() - thread_local_args = (thread_id,) + arg_instances - if thread_local_args in to_wrap.__already_called: # type: ignore + thread_local_args = (threading.get_ident(), *map(id, args)) + if thread_local_args in already_called: raise Web3ValueError(f"Recursively called {to_wrap} with {args!r}") - to_wrap.__already_called[thread_local_args] = True # type: ignore + add_call(thread_local_args) try: - wrapped_val = to_wrap(*args) + return to_wrap(*args) finally: - del to_wrap.__already_called[thread_local_args] # type: ignore - return wrapped_val + remove_call(thread_local_args) return wrapped diff --git a/web3/_utils/method_formatters.py b/web3/_utils/method_formatters.py index 71f27e1970..2a69a22bd9 100644 --- a/web3/_utils/method_formatters.py +++ b/web3/_utils/method_formatters.py @@ -7,6 +7,7 @@ Collection, Dict, Iterable, + Iterator, NoReturn, Tuple, TypeVar, @@ -35,7 +36,6 @@ is_string, to_checksum_address, to_list, - to_tuple, ) from eth_utils.toolz import ( complement, @@ -101,6 +101,7 @@ BlockIdentifier, Formatters, RPCEndpoint, + RPCResponse, SimulateV1Payload, StateOverrideParams, TReturn, @@ -115,6 +116,11 @@ TValue = TypeVar("TValue") +CachedFormatters = Dict[ + Union[RPCEndpoint, Callable[..., RPCEndpoint]], + Callable[[RPCResponse], Any], +] + def bytes_to_ascii(value: bytes) -> str: return codecs.decode(value, "ascii") @@ -600,19 +606,15 @@ def storage_key_to_hexstr(value: Union[bytes, int, str]) -> HexStr: ) block_result_formatters_copy = BLOCK_RESULT_FORMATTERS.copy() -block_result_formatters_copy.update( - { - "calls": apply_list_to_array_formatter( - type_aware_apply_formatters_to_dict( - { - "returnData": HexBytes, - "logs": apply_list_to_array_formatter(log_entry_formatter), - "gasUsed": to_integer_if_hex, - "status": to_integer_if_hex, - } - ) - ) - } +block_result_formatters_copy["calls"] = apply_list_to_array_formatter( + type_aware_apply_formatters_to_dict( + { + "returnData": HexBytes, + "logs": apply_list_to_array_formatter(log_entry_formatter), + "gasUsed": to_integer_if_hex, + "status": to_integer_if_hex, + } + ) ) simulate_v1_result_formatter = apply_formatter_if( is_not_null, @@ -1046,11 +1048,10 @@ def subscription_formatter(value: Any) -> Union[HexBytes, HexStr, Dict[str, Any] } -@to_tuple def combine_formatters( formatter_maps: Collection[Dict[RPCEndpoint, Callable[..., TReturn]]], method_name: RPCEndpoint, -) -> Iterable[Callable[..., TReturn]]: +) -> Iterator[Callable[..., TReturn]]: for formatter_map in formatter_maps: if method_name in formatter_map: yield formatter_map[method_name] @@ -1069,7 +1070,7 @@ def get_request_formatters( PYTHONIC_REQUEST_FORMATTERS, ) formatters = combine_formatters(request_formatter_maps, method_name) - return compose(*formatters) + return compose(tuple, *formatters) def raise_block_not_found(params: Tuple[BlockIdentifier, bool]) -> NoReturn: @@ -1188,12 +1189,11 @@ def filter_wrapper( } -@to_tuple def apply_module_to_formatters( - formatters: Tuple[Callable[..., TReturn]], + formatters: Iterable[Callable[..., TReturn]], module: "Module", method_name: Union[RPCEndpoint, Callable[..., RPCEndpoint]], -) -> Iterable[Callable[..., TReturn]]: +) -> Iterator[Callable[..., TReturn]]: for f in formatters: yield partial(f, module, method_name) @@ -1201,7 +1201,7 @@ def apply_module_to_formatters( def get_result_formatters( method_name: Union[RPCEndpoint, Callable[..., RPCEndpoint]], module: "Module", -) -> Dict[str, Callable[..., Any]]: +) -> Callable[[RPCResponse], Any]: formatters = combine_formatters((PYTHONIC_RESULT_FORMATTERS,), method_name) formatters_requiring_module = combine_formatters( (FILTER_RESULT_FORMATTERS,), method_name @@ -1212,19 +1212,28 @@ def get_result_formatters( return compose(*partial_formatters, *formatters) +_error_formatters: CachedFormatters = {} + def get_error_formatters( method_name: Union[RPCEndpoint, Callable[..., RPCEndpoint]] -) -> Callable[..., Any]: +) -> Callable[[RPCResponse], Any]: # Note error formatters work on the full response dict - error_formatter_maps = (ERROR_FORMATTERS,) - formatters = combine_formatters(error_formatter_maps, method_name) + formatters = _error_formatters.get(method_name) + if formatters is None: + formatters = _error_formatters[method_name] = compose( + *combine_formatters((ERROR_FORMATTERS,), method_name) + ) + return formatters - return compose(*formatters) +_null_result_formatters: CachedFormatters = {} def get_null_result_formatters( method_name: Union[RPCEndpoint, Callable[..., RPCEndpoint]] -) -> Callable[..., Any]: - formatters = combine_formatters((NULL_RESULT_FORMATTERS,), method_name) - - return compose(*formatters) +) -> Callable[[RPCResponse], Any]: + formatters = _null_result_formatters.get(method_name) + if formatters is None: + formatters = _null_result_formatters[method_name] = compose( + *combine_formatters((NULL_RESULT_FORMATTERS,), method_name) + ) + return formatters diff --git a/web3/_utils/rpc_abi.py b/web3/_utils/rpc_abi.py index c401309c55..8220a4a428 100644 --- a/web3/_utils/rpc_abi.py +++ b/web3/_utils/rpc_abi.py @@ -218,8 +218,9 @@ def apply_abi_formatters_to_dict( [abi_dict[field] for field in fields], [data[field] for field in fields], ) - formatted_dict = dict(zip(fields, formatted_values)) - return dict(data, **formatted_dict) + formatted_dict = data.copy() + formatted_dict.update(zip(fields, formatted_values)) + return formatted_dict @to_dict diff --git a/web3/_utils/utility_methods.py b/web3/_utils/utility_methods.py index 66ad373ff2..45051655d8 100644 --- a/web3/_utils/utility_methods.py +++ b/web3/_utils/utility_methods.py @@ -1,7 +1,7 @@ from typing import ( Any, - Dict, Iterable, + Mapping, Set, Union, ) @@ -13,7 +13,7 @@ def all_in_dict( - values: Iterable[Any], d: Union[Dict[Any, Any], TxData, TxParams] + values: Iterable[Any], d: Union[Mapping[Any, Any], TxData, TxParams] ) -> bool: """ Returns a bool based on whether ALL of the provided values exist @@ -24,11 +24,12 @@ def all_in_dict( :return: True if ALL values exist in keys; False if NOT ALL values exist in keys """ - return all(_ in dict(d) for _ in values) + d = dict(d) + return all(_ in d for _ in values) def any_in_dict( - values: Iterable[Any], d: Union[Dict[Any, Any], TxData, TxParams] + values: Iterable[Any], d: Union[Mapping[Any, Any], TxData, TxParams] ) -> bool: """ Returns a bool based on whether ANY of the provided values exist @@ -39,11 +40,12 @@ def any_in_dict( :return: True if ANY value exists in keys; False if NONE of the values exist in keys """ - return any(_ in dict(d) for _ in values) + d = dict(d) + return any(_ in d for _ in values) def none_in_dict( - values: Iterable[Any], d: Union[Dict[Any, Any], TxData, TxParams] + values: Iterable[Any], d: Union[Mapping[Any, Any], TxData, TxParams] ) -> bool: """ Returns a bool based on whether NONE of the provided values exist diff --git a/web3/method.py b/web3/method.py index f67828f964..645901a670 100644 --- a/web3/method.py +++ b/web3/method.py @@ -14,13 +14,6 @@ ) import warnings -from eth_utils.curried import ( - to_tuple, -) -from eth_utils.toolz import ( - pipe, -) - from web3._utils.batching import ( RPC_METHODS_UNSUPPORTED_DURING_BATCH, ) @@ -56,16 +49,6 @@ Munger = Callable[..., Any] -@to_tuple -def _apply_request_formatters( - params: Any, request_formatters: Dict[RPCEndpoint, Callable[..., TReturn]] -) -> Tuple[Any, ...]: - if request_formatters: - formatted_params = pipe(params, request_formatters) - return formatted_params - return params - - def _set_mungers( mungers: Optional[Sequence[Munger]], is_property: bool ) -> Sequence[Any]: @@ -232,10 +215,12 @@ def process_params( get_error_formatters(method), self.null_result_formatters(method), ) - request = ( - method, - _apply_request_formatters(params, self.request_formatters(method)), - ) + + if request_formatters := self.request_formatters(method): + params = request_formatters(params) + + request = method, params + return request, response_formatters