Source code for pons._abi_types

# We need to have some module-private members in `Type`.
# ruff: noqa: SLF001

import re
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from types import EllipsisType
from typing import Any, cast

import eth_abi
from eth_abi.exceptions import DecodingError
from ethereum_rpc import Address, keccak

ABI_JSON = None | bool | int | float | str | Sequence["ABI_JSON"] | Mapping[str, "ABI_JSON"]
"""Values serializable to JSON."""

# Maximum bits in an `int` or `uint` type in Solidity.
MAX_INTEGER_BITS = 256

# Maximum size of a `bytes` type in Solidity.
MAX_BYTES_SIZE = 32


[docs] class ABIDecodingError(Exception): """Raised on an error when decoding a value in an Eth ABI encoded bytestring."""
ABIType = int | str | bytes | bool | list["ABIType"] """ Represents the argument type that can be received from ``eth_abi.decode()``, or passed to ``eth_abi.encode()``. """
[docs] class Type(ABC): """The base type for Solidity types.""" @property @abstractmethod def canonical_form(self) -> str: """Returns the type as a string in the canonical form (for ``eth_abi`` consumption).""" ... @abstractmethod def _normalize(self, val: Any) -> ABIType: """ Checks and possibly normalizes the value making it ready to be passed to ``encode()`` for encoding. """ ... @abstractmethod def _denormalize(self, val: ABIType) -> Any: """Checks the result of ``decode()`` and wraps it in a specific type, if applicable.""" ... def encode(self, val: Any) -> bytes: """Encodes the given value in the contract ABI format.""" return eth_abi.encode([self.canonical_form], [val]) def encode_to_topic(self, val: Any) -> bytes: """Encodes the given value as an event topic.""" # EVM uses a simpler encoding scheme for encoding values into event topics # because objects of reference types are just hashed, # and there is no need to unpack them later # (basically, all values are just concatenated without any lentgh labels). # Therefore we have to provide these methods # and cannot just use the functions from ``eth_abi``. # Before doing anything, normalize the value, # this will ensure the constituent values are actually valid. return self._encode_to_topic_outer(self._normalize(val)) def _encode_to_topic_outer(self, val: Any) -> bytes: """Encodes a value of the outer indexed type.""" # By default it's just the encoding of the value type. # May be overridden. return self.encode(val) def _encode_to_topic_inner(self, val: Any) -> bytes: """Encodes a value contained within an indexed array or struct.""" # By default it's just the encoding of the value type. # May be overridden. return self.encode(val) def decode_from_topic(self, val: bytes) -> Any | None: """ Decodes an encoded topic. Returns ``None`` if the decoding is impossible (that is, the original value was hashed). """ # This method does not have inner/outer division, since all reference types are hashed, # and there's no need to go recursively into structs/arrays - we won't be able to recover # the values anyway. # By default it's just the decoding of the value type. # May be overridden. return self._denormalize(eth_abi.decode([self.canonical_form], val)[0]) def __str__(self) -> str: return self.canonical_form def __getitem__(self, array_size: int | EllipsisType) -> "Array": # In Py3.10 they added EllipsisType which would work better here. # For now, relying on the documentation. if isinstance(array_size, int): return Array(self, array_size) if array_size == ...: return Array(self, None) raise TypeError(f"Invalid array size specifier type: {type(array_size).__name__}")
[docs] class UInt(Type): """Corresponds to the Solidity ``uint<bits>`` type.""" def __init__(self, bits: int): if bits <= 0 or bits > MAX_INTEGER_BITS or bits % 8 != 0: raise ValueError(f"Incorrect `uint` bit size: {bits}") self._bits = bits @property def canonical_form(self) -> str: return f"uint{self._bits}" def _check_val(self, val: Any) -> int: # `bool` is a subclass of `int`, but we would rather be more strict # and prevent possible bugs. if not isinstance(val, int) or isinstance(val, bool): raise TypeError( f"`{self.canonical_form}` must correspond to an integer, got {type(val).__name__}" ) if val < 0: raise ValueError( f"`{self.canonical_form}` must correspond to a non-negative integer, got {val}" ) if val >> self._bits != 0: raise ValueError( f"`{self.canonical_form}` must correspond to an unsigned integer " f"under {self._bits} bits, got {val}" ) return int(val) def _normalize(self, val: Any) -> int: return self._check_val(val) def _denormalize(self, val: ABIType) -> int: return self._check_val(val) def __eq__(self, other: object) -> bool: return isinstance(other, UInt) and self._bits == other._bits def __hash__(self) -> int: return hash((type(self), self._bits))
[docs] class Int(Type): """Corresponds to the Solidity ``int<bits>`` type.""" def __init__(self, bits: int): if bits <= 0 or bits > MAX_INTEGER_BITS or bits % 8 != 0: raise ValueError(f"Incorrect `int` bit size: {bits}") self._bits = bits @property def canonical_form(self) -> str: return f"int{self._bits}" def _check_val(self, val: Any) -> int: # `bool` is a subclass of `int`, but we would rather be more strict # and prevent possible bugs. if not isinstance(val, int) or isinstance(val, bool): raise TypeError( f"`{self.canonical_form}` must correspond to an integer, got {type(val).__name__}" ) if (val + (1 << (self._bits - 1))) >> self._bits != 0: raise ValueError( f"`{self.canonical_form}` must correspond to a signed integer " f"under {self._bits} bits, got {val}" ) return int(val) def _normalize(self, val: Any) -> int: return self._check_val(val) def _denormalize(self, val: ABIType) -> int: return self._check_val(val) def __eq__(self, other: object) -> bool: return isinstance(other, Int) and self._bits == other._bits def __hash__(self) -> int: return hash((type(self), self._bits))
[docs] class Bytes(Type): """Corresponds to the Solidity ``bytes<size>`` type.""" def __init__(self, size: None | int = None): if size is not None and (size <= 0 or size > MAX_BYTES_SIZE): raise ValueError(f"Incorrect `bytes` size: {size}") self._size = size @property def canonical_form(self) -> str: return f"bytes{self._size if self._size else ''}" def _check_val(self, val: Any) -> bytes: if not isinstance(val, bytes): raise TypeError( f"`{self.canonical_form}` must correspond to a bytestring, got {type(val).__name__}" ) if self._size is not None and len(val) != self._size: raise ValueError(f"Expected {self._size} bytes, got {len(val)}") return val def _normalize(self, val: Any) -> bytes: return self._check_val(val) def _denormalize(self, val: ABIType) -> bytes: return self._check_val(val) def _encode_to_topic_outer(self, val: bytes) -> bytes: if self._size is None: # Dynamic `bytes` is a reference type and is therefore hashed. return keccak(val) # Sized `bytes` is a value type, falls back to the base implementation. return super()._encode_to_topic_outer(val) def _encode_to_topic_inner(self, val: bytes) -> bytes: if self._size is None: # Dynamic `bytes` is padded to a multiple of 32 bytes. padding_len = (32 - len(val)) % 32 return val + b"\x00" * padding_len # Sized `bytes` is a value type, falls back to the base implementation. return super()._encode_to_topic_inner(val) def decode_from_topic(self, val: bytes) -> None | bytes: if self._size is None: # Cannot recover a hashed value. return None return super().decode_from_topic(val) def __eq__(self, other: object) -> bool: return isinstance(other, Bytes) and self._size == other._size def __hash__(self) -> int: return hash((type(self), self._size))
[docs] class AddressType(Type): """ Corresponds to the Solidity ``address`` type. Not to be confused with :py:class:`ethereum_rpc.Address` which represents an address value. """ @property def canonical_form(self) -> str: return "address" def _normalize(self, val: Any) -> str: if not isinstance(val, Address): raise TypeError( f"`address` must correspond to an `Address`-type value, got {type(val).__name__}" ) return val.checksum def _denormalize(self, val: ABIType) -> Address: if not isinstance(val, str): raise TypeError(f"Expected a string to convert to `Address`, got {type(val).__name__}") return Address.from_hex(val) def __eq__(self, other: object) -> bool: return isinstance(other, AddressType) def __hash__(self) -> int: return hash(type(self))
[docs] class String(Type): """Corresponds to the Solidity ``string`` type.""" @property def canonical_form(self) -> str: return "string" def _check_val(self, val: Any) -> str: if not isinstance(val, str): raise TypeError( f"`string` must correspond to a `str`-type value, got {type(val).__name__}" ) return val def _normalize(self, val: Any) -> str: return self._check_val(val) def _denormalize(self, val: ABIType) -> str: return self._check_val(val) def _encode_to_topic_outer(self, val: str) -> bytes: # `string` is encoded and treated as dynamic `bytes` return Bytes()._encode_to_topic_outer(val.encode()) def _encode_to_topic_inner(self, val: str) -> bytes: # `string` is encoded and treated as dynamic `bytes` return Bytes()._encode_to_topic_inner(val.encode()) def decode_from_topic(self, _val: bytes) -> None: # Dynamic `string` is hashed, so the value cannot be recovered. return None def __eq__(self, other: object) -> bool: return isinstance(other, String) def __hash__(self) -> int: return hash(type(self))
[docs] class Bool(Type): """Corresponds to the Solidity ``bool`` type.""" @property def canonical_form(self) -> str: return "bool" def _check_val(self, val: Any) -> bool: if not isinstance(val, bool): raise TypeError( f"`bool` must correspond to a `bool`-type value, got {type(val).__name__}" ) return val def _normalize(self, val: Any) -> bool: return self._check_val(val) def _denormalize(self, val: ABIType) -> bool: return self._check_val(val) def __eq__(self, other: object) -> bool: return isinstance(other, Bool) def __hash__(self) -> int: return hash(type(self))
class Array(Type): """Corresponds to the Solidity array (``[<size>]``) type.""" def __init__(self, element_type: Type, size: None | int = None): self._element_type = element_type self._size = size @cached_property def canonical_form(self) -> str: return ( self._element_type.canonical_form + "[" + (str(self._size) if self._size else "") + "]" ) def _check_val(self, val: Any) -> Sequence[Any]: if not isinstance(val, Sequence): raise TypeError(f"Expected an iterable, got {type(val).__name__}") if self._size is not None and len(val) != self._size: raise ValueError(f"Expected {self._size} elements, got {len(val)}") return val def _normalize(self, val: Any) -> list[ABIType]: return [self._element_type._normalize(item) for item in self._check_val(val)] def _denormalize(self, val: ABIType) -> list[ABIType]: return [self._element_type._denormalize(item) for item in self._check_val(val)] def _encode_to_topic_outer(self, val: Any) -> bytes: return keccak(self._encode_to_topic_inner(val)) def _encode_to_topic_inner(self, val: Any) -> bytes: return b"".join(self._element_type._encode_to_topic_inner(elem) for elem in val) def decode_from_topic(self, _val: Any) -> None: return None def __eq__(self, other: object) -> bool: return ( isinstance(other, Array) and self._element_type == other._element_type and self._size == other._size ) def __hash__(self) -> int: return hash((type(self), self._element_type, self._size))
[docs] class Struct(Type): """Corresponds to the Solidity struct type.""" def __init__(self, fields: Mapping[str, Type]): self._fields = fields @cached_property def canonical_form(self) -> str: return "(" + ",".join(field.canonical_form for field in self._fields.values()) + ")" def _check_val(self, val: Any) -> Sequence[Any]: if not isinstance(val, Sequence): raise TypeError(f"Expected an iterable, got {type(val).__name__}") if len(val) != len(self._fields): raise ValueError(f"Expected {len(self._fields)} elements, got {len(val)}") return val def _normalize(self, val: Any) -> list[ABIType]: if isinstance(val, Mapping): if val.keys() != self._fields.keys(): raise ValueError( f"Expected fields {list(self._fields.keys())}, got {list(val.keys())}" ) return [tp._normalize(val[name]) for name, tp in self._fields.items()] return [ tp._normalize(item) for item, tp in zip(self._check_val(val), self._fields.values(), strict=True) ] def _denormalize(self, val: ABIType) -> dict[str, ABIType]: return { name: tp._denormalize(item) for item, (name, tp) in zip(self._check_val(val), self._fields.items(), strict=True) } def _encode_to_topic_outer(self, val: Any) -> bytes: return keccak(self._encode_to_topic_inner(val)) def _encode_to_topic_inner(self, val: Any) -> bytes: return b"".join( tp._encode_to_topic_inner(elem) for elem, tp in zip(val, self._fields.values(), strict=True) ) def decode_from_topic(self, _val: Any) -> None: return None def __str__(self) -> str: # Overriding the `Type`'s implementation because we want to show the field names too return "(" + ", ".join(str(tp) + " " + str(name) for name, tp in self._fields.items()) + ")" def __eq__(self, other: object) -> bool: return ( isinstance(other, Struct) and self._fields == other._fields # structs with the same fields but in different order are not equal and list(self._fields) == list(other._fields) ) def __hash__(self) -> int: return hash((type(self), tuple(self._fields.items())))
_UINT_RE = re.compile(r"uint(\d+)") _INT_RE = re.compile(r"int(\d+)") _BYTES_RE = re.compile(r"bytes(\d+)?") _NO_PARAMS = { "address": AddressType(), "string": String(), "bool": Bool(), } def type_from_abi_string(abi_string: str) -> Type: if match := _UINT_RE.match(abi_string): return UInt(int(match.group(1))) if match := _INT_RE.match(abi_string): return Int(int(match.group(1))) if match := _BYTES_RE.match(abi_string): size = match.group(1) return Bytes(int(size) if size else None) if abi_string in _NO_PARAMS: return _NO_PARAMS[abi_string] raise ValueError(f"Unknown type: {abi_string}") def dispatch_type(abi_entry: ABI_JSON) -> Type: # TODO (#83): use proper validation abi_entry_typed = cast("Mapping[str, Any]", abi_entry) type_str = abi_entry_typed["type"] match = re.match(r"^([\w\d\[\]]*?)(\[(\d+)?\])?$", type_str) if not match: raise ValueError(f"Incorrect type format: {type_str}") element_type_name = match.group(1) is_array = match.group(2) array_size = match.group(3) if array_size is not None: array_size = int(array_size) if is_array: element_entry = dict(abi_entry_typed) element_entry["type"] = element_type_name element_type = dispatch_type(element_entry) return Array(element_type, array_size) if element_type_name == "tuple": fields = {} for component in abi_entry_typed["components"]: fields[component["name"]] = dispatch_type(component) return Struct(fields) return type_from_abi_string(element_type_name) def dispatch_types(abi_entry: ABI_JSON) -> list[Type] | dict[str, Type]: # TODO (#83): use proper validation abi_entry_typed = cast("Iterable[dict[str, Any]]", abi_entry) names = [entry["name"] for entry in abi_entry_typed] # Unnamed arguments; treat as positional arguments if names and all(not name for name in names): return [dispatch_type(entry) for entry in abi_entry_typed] if any(not name for name in names): raise ValueError("Arguments must be either all named or all unnamed") # Since we are returning a dictionary, need to be sure we don't silently merge entries if len(names) != len(set(names)): raise ValueError("All ABI entries must have distinct names") return {entry["name"]: dispatch_type(entry) for entry in abi_entry_typed} def encode_args(*types_and_args: tuple[Type, Any]) -> bytes: if types_and_args: types, args = zip(*types_and_args, strict=True) else: types, args = (), () return eth_abi.encode( [tp.canonical_form for tp in types], tuple(tp._normalize(arg) for tp, arg in zip(types, args, strict=True)), ) def decode_args(types: Iterable[Type], data: bytes) -> tuple[ABIType, ...]: canonical_types = [tp.canonical_form for tp in types] try: values = eth_abi.decode(canonical_types, data) except DecodingError as exc: # wrap possible `eth_abi` errors signature = "(" + ",".join(canonical_types) + ")" message = ( f"Could not decode the return value with the expected signature {signature}: {exc}" ) raise ABIDecodingError(message) from exc return tuple(tp._denormalize(value) for tp, value in zip(types, values, strict=True))