Source code for pons._fallback_provider

import threading
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterable, Mapping
from collections.abc import Set as AbstractSet
from contextlib import AsyncExitStack, asynccontextmanager

from ._provider import (
    RPC_JSON,
    Provider,
    ProviderError,
    ProviderPath,
    ProviderSession,
)


[docs] class FallbackStrategy(ABC): """An abstract class defining a fallback strategy for multiple providers."""
[docs] @abstractmethod def get_provider_order(self) -> list[str]: """ Returns the suggested order of providers to query, based on the accumulated data. This method is called once on every high-level request to the provider. """
[docs] class FallbackStrategyFactory(ABC): """ An abstract class defining a fallback strategy factory for multiple providers. This will be called in ``FallbackProvider`` to create an actual strategy object (which may be mutated). """
[docs] @abstractmethod def make_strategy(self, provider_ids: AbstractSet[str]) -> FallbackStrategy: """Returns a strategy object."""
class CycleFallbackStrategy(FallbackStrategy): def __init__(self, weights: dict[str, int]): self._weights = weights self._provider_ids = list(weights.keys()) self._counter = 0 def get_provider_order(self) -> list[str]: if self._counter == self._weights[self._provider_ids[0]]: self._counter = 0 self._provider_ids = [*self._provider_ids[1:], self._provider_ids[0]] self._counter += 1 return self._provider_ids
[docs] class CycleFallback(FallbackStrategyFactory): """ Creates a strategy where the providers are cycled such that the number of times a given provider is first in the priority list is equal to the corresponding entry in ``weights`` (the length of which should match the number of providers). If ``weights`` is not given, a list of ``1`` will be used. """ def __init__(self, weights: None | Iterable[int] = None): self._weights: None | list[int] if weights: self._weights = list(weights) else: self._weights = None def make_strategy(self, provider_ids: AbstractSet[str]) -> CycleFallbackStrategy: num_providers = len(provider_ids) weights = self._weights or [1] * num_providers if len(weights) != num_providers: raise ValueError( f"Length of the weights ({len(weights)}) " f"inconsistent with the number of providers ({num_providers})" ) return CycleFallbackStrategy(dict(zip(provider_ids, weights, strict=True)))
class PriorityFallbackStrategy(FallbackStrategy): """ Creates a strategy where the providers are queried in the order they were given to ``FallbackProvider``, until a successful response is received. """ def __init__(self, provider_ids: AbstractSet[str]): self._provider_ids = list(provider_ids) def get_provider_order(self) -> list[str]: return self._provider_ids
[docs] class PriorityFallback(FallbackStrategyFactory): def make_strategy(self, provider_ids: AbstractSet[str]) -> PriorityFallbackStrategy: return PriorityFallbackStrategy(provider_ids)
[docs] class FallbackProvider(Provider): """ A provider that encompasses several providers and for every request tries every one of them until the request is successful. The order is chosen according to the given strategy. If ``same_provider`` is ``True``, the given providers are treated as endpoints pointing to the same physical provider, for the purpose of stateful requests (e.g. filter creation). If ``strategy`` is ``None``, an instance of :py:class:`PriorityFallback` is used. If a request attempt results in an error for which ``use_fallback`` returns ``True``, the next provider based on the chosen strategy will be selected. Otherwise (or if it is the last provider), the error is raised normally. """ def __init__( self, providers: Mapping[str, Provider], strategy: FallbackStrategyFactory | None = None, *, same_provider: bool = False, ): self._providers = dict(providers) strategy_ = strategy if strategy is not None else PriorityFallback() self._strategy = strategy_.make_strategy(self._providers.keys()) self._same_provider = same_provider self._errors: dict[str, ProviderError] = {} self._lock = threading.Lock() @asynccontextmanager async def session(self) -> AsyncIterator["ProviderSession"]: async with AsyncExitStack() as stack: sessions = { provider_id: await stack.enter_async_context(provider.session()) for provider_id, provider in self._providers.items() } yield FallbackProviderSession( self, sessions, self._strategy, same_provider=self._same_provider, )
[docs] def errors(self) -> list[tuple[ProviderPath, ProviderError]]: """ Returns the list of recorded errors for sub-providers. Only the most recent error for every sub-provider is recorded. Querying this method clears the recorded errors. """ errors = [] with self._lock: for provider_id, error in self._errors.items(): errors.append((ProviderPath([provider_id]), error)) self._errors = {} for provider_id, provider in self._providers.items(): if isinstance(provider, FallbackProvider): sub_errors = provider.errors() for sub_path, error in sub_errors: errors.append((sub_path.group(provider_id), error)) return errors
def record_error(self, provider_id: str, exc: ProviderError) -> None: with self._lock: self._errors[provider_id] = exc
class FallbackProviderSession(ProviderSession): def __init__( self, provider: FallbackProvider, sessions: dict[str, ProviderSession], strategy: FallbackStrategy, *, same_provider: bool, ): self._provider = provider self._sessions = sessions self._strategy = strategy self._same_provider = same_provider async def rpc_and_pin(self, method: str, *args: RPC_JSON) -> tuple[RPC_JSON, ProviderPath]: provider_ids = self._strategy.get_provider_order() for i, provider_id in enumerate(provider_ids): session = self._sessions[provider_id] try: result, sub_path = await session.rpc_and_pin(method, *args) except ProviderError as exc: if not isinstance(session, FallbackProviderSession): self._provider.record_error(provider_id, exc) if i < len(provider_ids) - 1: continue raise break else: # pragma: no cover # This branch will never be reached, because the loop will either return, # or raise an exception. raise NotImplementedError return result, sub_path.group(provider_id) async def rpc(self, method: str, *args: RPC_JSON) -> RPC_JSON: result, _path = await self.rpc_and_pin(method, *args) return result async def rpc_at_pin(self, path: ProviderPath, method: str, *args: RPC_JSON) -> RPC_JSON: if self._same_provider: return await self.rpc(method, *args) if path.is_empty(): raise ValueError("Expected a non-empty provider path") provider_id, sub_path = path.ungroup() if provider_id not in self._sessions: raise ValueError(f"Provider id `{provider_id}` not found") return await self._sessions[provider_id].rpc_at_pin(sub_path, method, *args)