Source code for pons.local_provider
"""
PyEVM-based provider for tests.
Requires the dependencies from the ``local-provider`` feature.
"""
import itertools
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import Any
from alysis import EVMVersion, Node, RPCNode
from eth_account import Account
from ethereum_rpc import Amount, RPCError
from ._provider import RPC_JSON, Provider, ProviderError, ProviderSession
from ._signer import AccountSigner
# Public names exported by this module (also what ``from ... import *`` picks up).
__all__ = ["LocalProvider", "SnapshotID"]
[docs]
class SnapshotID:
    """An ID of a snapshot in a :py:class:`LocalProvider`."""

    def __init__(self, id_: int):
        # Integer key under which the provider stored the snapshot.
        self.id_ = id_

    def __repr__(self) -> str:
        # Readable form for test failures and logs instead of the default object address.
        return f"{type(self).__name__}({self.id_!r})"
[docs]
class LocalProvider(Provider):
    """A provider maintaining its own chain state, useful for tests."""

    root: AccountSigner
    """The signer for the pre-created account."""

    def __init__(
        self,
        *,
        root_balance: Amount,
        chain_id: int = 1,
        evm_version: EVMVersion = EVMVersion.PRAGUE,
    ):
        """
        Creates a fresh in-process chain.

        :param root_balance: initial balance of the pre-created root account.
        :param chain_id: chain ID reported by the node.
        :param evm_version: EVM rules the node executes with.
        """
        self._local_node = Node(
            root_balance_wei=root_balance.as_wei(), chain_id=chain_id, evm_version=evm_version
        )
        # All RPC requests are dispatched through this wrapper around the live node.
        self._rpc_node = RPCNode(self._local_node)
        self.root = AccountSigner(Account.from_key(self._local_node.root_private_key))
        self._default_address = self.root.address
        # Source of snapshot IDs; each call yields a new integer, so IDs are never reused.
        self._snapshot_counter = itertools.count()
        self._snapshots: dict[int, Node] = {}

    def disable_auto_mine_transactions(self) -> None:
        """Disable mining a new block after each transaction."""
        self._local_node.disable_auto_mine_transactions()

    def enable_auto_mine_transactions(self) -> None:
        """
        Enable mining a new block after each transaction.
        This is the default behavior.
        """
        self._local_node.enable_auto_mine_transactions()

    def take_snapshot(self) -> SnapshotID:
        """Creates a snapshot of the chain state internally and returns its ID."""
        snapshot_id = next(self._snapshot_counter)
        # Store an independent copy so that further activity on the live node
        # does not leak into the snapshot.
        self._snapshots[snapshot_id] = deepcopy(self._local_node)
        return SnapshotID(snapshot_id)

    def revert_to_snapshot(self, snapshot_id: SnapshotID) -> None:
        """
        Restores the chain state to the snapshot with the given ID.

        The stored snapshot itself is left untouched, so the same ID can be reverted to
        any number of times. Raises ``KeyError`` if this provider holds no snapshot
        with the given ID.
        """
        # Restore from a copy: installing the stored node directly would alias it with
        # the live state, and subsequent transactions would silently modify the snapshot.
        self._local_node = deepcopy(self._snapshots[snapshot_id.id_])
        self._rpc_node = RPCNode(self._local_node)

    def rpc(self, method: str, *args: Any) -> RPC_JSON:  # noqa: D102
        # Node-level RPC failures are re-raised as the provider-agnostic ``ProviderError``,
        # keeping the original exception as the cause.
        try:
            return self._rpc_node.rpc(method, *args)
        except RPCError as exc:
            raise ProviderError(exc) from exc

    @asynccontextmanager
    async def session(self) -> AsyncIterator["LocalProviderSession"]:  # noqa: D102
        yield LocalProviderSession(self)
class LocalProviderSession(ProviderSession):
    """
    Session object handed out by :py:meth:`LocalProvider.session`.

    Every request is passed straight to the owning provider's synchronous ``rpc``;
    nothing is awaited inside.
    """

    def __init__(self, provider: LocalProvider):
        self._provider = provider

    async def rpc(self, method: str, *args: RPC_JSON) -> RPC_JSON:
        owner = self._provider
        response = owner.rpc(method, *args)
        return response