Source code for pons.local_provider

"""
PyEVM-based provider for tests.
Requires the dependencies from the ``local-provider`` feature.
"""

import itertools
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import Any

from alysis import EVMVersion, Node, RPCNode
from eth_account import Account
from ethereum_rpc import Amount, RPCError

from ._provider import RPC_JSON, Provider, ProviderError, ProviderSession
from ._signer import AccountSigner

__all__ = ["LocalProvider", "SnapshotID"]


class SnapshotID:
    """An ID of a snapshot in a :py:class:`LocalProvider`."""

    def __init__(self, id_: int):
        # Key of the snapshot in the owning provider's internal snapshot table.
        self.id_ = id_

    def __repr__(self) -> str:
        # Show the raw ID so snapshots are distinguishable in test output and logs.
        return f"{type(self).__name__}({self.id_!r})"
class LocalProvider(Provider):
    """A provider maintaining its own chain state, useful for tests."""

    root: AccountSigner
    """The signer for the pre-created account."""

    def __init__(
        self,
        *,
        root_balance: Amount,
        chain_id: int = 1,
        evm_version: EVMVersion = EVMVersion.PRAGUE,
    ):
        """
        Creates a fresh local chain with a pre-funded root account.

        ``root_balance`` is the initial balance of the root account (:py:attr:`root`);
        ``chain_id`` and ``evm_version`` are passed through to the underlying node.
        """
        self._local_node = Node(
            root_balance_wei=root_balance.as_wei(), chain_id=chain_id, evm_version=evm_version
        )
        self._rpc_node = RPCNode(self._local_node)
        self.root = AccountSigner(Account.from_key(self._local_node.root_private_key))
        self._default_address = self.root.address
        # Monotonic source of snapshot IDs; IDs are never reused within one provider.
        self._snapshot_counter = itertools.count()
        self._snapshots: dict[int, Node] = {}

    def disable_auto_mine_transactions(self) -> None:
        """Disable mining a new block after each transaction."""
        self._local_node.disable_auto_mine_transactions()

    def enable_auto_mine_transactions(self) -> None:
        """
        Enable mining a new block after each transaction.
        This is the default behavior.
        """
        self._local_node.enable_auto_mine_transactions()

    def take_snapshot(self) -> SnapshotID:
        """Creates a snapshot of the chain state internally and returns its ID."""
        snapshot_id = next(self._snapshot_counter)
        # Store an independent copy so activity on the live node cannot alter it.
        self._snapshots[snapshot_id] = deepcopy(self._local_node)
        return SnapshotID(snapshot_id)

    def revert_to_snapshot(self, snapshot_id: SnapshotID) -> None:
        """
        Restores the chain state to the snapshot with the given ID.

        The stored snapshot itself is left untouched, so the same ID can be
        reverted to any number of times.

        Raises ``KeyError`` if no snapshot with this ID is stored.
        """
        # Install a copy, not the stored node: adopting the stored object directly
        # would let subsequent transactions mutate the snapshot, so a second revert
        # to the same ID would restore the mutated state instead of the original.
        self._local_node = deepcopy(self._snapshots[snapshot_id.id_])
        # The RPC wrapper is bound to a specific node, so it must be rebuilt.
        self._rpc_node = RPCNode(self._local_node)

    def rpc(self, method: str, *args: Any) -> RPC_JSON:
        """
        Executes an RPC request against the local node.

        Errors reported by the node as ``RPCError`` are re-raised as
        :py:class:`ProviderError`, chained to the original exception.
        """
        try:
            return self._rpc_node.rpc(method, *args)
        except RPCError as exc:
            raise ProviderError(exc) from exc

    @asynccontextmanager
    async def session(self) -> AsyncIterator["LocalProviderSession"]:
        """Opens a session whose requests are forwarded to this provider's ``rpc``."""
        yield LocalProviderSession(self)
class LocalProviderSession(ProviderSession):
    """
    Session bound to a :py:class:`LocalProvider`.

    Every request is delegated to the provider's ``rpc`` method.
    """

    def __init__(self, provider: LocalProvider):
        self._provider = provider

    async def rpc(self, method: str, *args: RPC_JSON) -> RPC_JSON:
        # The provider's ``rpc`` is a plain synchronous call, so no await is needed.
        target = self._provider
        response = target.rpc(method, *args)
        return response