Source code for asynctnt.connection

import asyncio
import enum
import functools
import os
from typing import Optional, Type, Union

from .api import Api
from .exceptions import ErrorCode, TarantoolDatabaseError, TarantoolError
from .iproto import protocol
from .log import logger
from .stream import Stream

__all__ = ("Connection", "connect", "ConnectionState")


[docs] class ConnectionState(enum.IntEnum):
[docs] CONNECTING = 1
[docs] CONNECTED = 2
[docs] RECONNECTING = 3
[docs] DISCONNECTING = 4
[docs] DISCONNECTED = 5
[docs] class Connection(Api):
[docs] __slots__ = ( "_host", "_port", "_username", "_password", "_fetch_schema", "_auto_refetch_schema", "_initial_read_buffer_size", "_encoding", "_connect_timeout", "_reconnect_timeout", "_request_timeout", "_ping_timeout", "_state", "_state_prev", "_transport", "_protocol", "_disconnect_waiter", "_reconnect_task", "_connect_lock", "_disconnect_lock", "_ping_task", )
def __init__( self, *, host: str = "127.0.0.1", port: Union[int, str] = 3301, username: Optional[str] = None, password: Optional[str] = None, fetch_schema: bool = True, auto_refetch_schema: bool = True, connect_timeout: float = 3.0, request_timeout: float = -1.0, reconnect_timeout: float = 1.0 / 3.0, ping_timeout: float = 5.0, encoding: Optional[str] = None, initial_read_buffer_size: Optional[int] = None, ): """ Connection constructor. To manipulate a Connection instance there are several functions: * await connect() - performs connecting, authorization and schema fetching. * await disconnect() - performs disconnection. * close() - closes connection (not a coroutine) Connection also supports context manager protocol, which connects on entering and disconnecting on leaving a block. So one can simply use it as follows: .. code-block:: python async with asynctnt.Connection() as conn: await conn.call('box.info') :param host: Tarantool host (pass ``unix/`` to connect to unix socket) :param port: Tarantool port (pass ``/path/to/sockfile`` to connect ot unix socket) :param username: Username to use for auth (if ``None`` you are connected as a guest) :param password: Password to use for auth :param fetch_schema: Pass ``True`` to be able to use spaces and indexes names in data manipulation routines (default is ``True``) :param auto_refetch_schema: If set to ``True`` then when ER_WRONG_SCHEMA_VERSION error occurs on a request, schema is refetched and the initial request is resent. If set to ``False`` then schema will not be checked by Tarantool, so no errors will occur :param connect_timeout: Time in seconds how long to wait for connecting to socket :param request_timeout: Request timeout (in seconds) for all requests (by default there is no timeout) :param reconnect_timeout: Time in seconds to wait before automatic reconnect (set to ``0`` or ``None`` to disable auto reconnect) :param ping_timeout: If specified (default is 5 seconds) a background task will be created which will ping Tarantool instance periodically to check if it is alive and update schema if it is changed (set to ``0`` or ``None`` to disable this task) :param encoding: The encoding to use for all strings encoding and decoding (default is ``utf-8``) :param initial_read_buffer_size: Initial and minimum size of read buffer in bytes. Higher value means less reallocations, but higher memory usage (default is 131072). """ super().__init__() self._host = host self._port = port self._username = username self._password = password self._fetch_schema = False if fetch_schema is None else fetch_schema if auto_refetch_schema: # None hack self._auto_refetch_schema = True if not self._fetch_schema: logger.warning( "Setting fetch_schema to True as " "auto_refetch_schema is True" ) self._fetch_schema = True else: self._auto_refetch_schema = False self._initial_read_buffer_size = initial_read_buffer_size self._encoding = encoding or "utf-8" self._connect_timeout = connect_timeout self._reconnect_timeout = reconnect_timeout or 0 self._request_timeout = request_timeout self._ping_timeout = ping_timeout or 0 self._transport = None self._protocol: Optional[protocol.Protocol] = None self._state = ConnectionState.DISCONNECTED self._state_prev = ConnectionState.DISCONNECTED self._disconnect_waiter = None self._reconnect_task = None self._connect_lock = asyncio.Lock() self._disconnect_lock = asyncio.Lock() self._ping_task = None
[docs] def _set_state(self, new_state: ConnectionState): if self._state != new_state: logger.debug("Changing state %s -> %s", self._state.name, new_state.name) self._state_prev = self._state self._state = new_state
[docs] def connection_lost(self, exc): if self._transport: self._transport.close() self._transport = None if self._disconnect_waiter: # disconnect() call happened self._disconnect_waiter.set_result(True) return # connection lost if self._reconnect_timeout > 0: # should reconnect self._start_reconnect(return_exceptions=False) else: # should not reconnect, close everything self.close()
[docs] async def _ping_task_func(self): while self._state == ConnectionState.CONNECTED: try: await self.ping(timeout=2.0) except asyncio.CancelledError: break except Exception: pass await asyncio.sleep(self._ping_timeout)
[docs] def _start_reconnect(self, return_exceptions: bool = False): if self._state in [ConnectionState.CONNECTING, ConnectionState.RECONNECTING]: logger.debug( "%s Cannot start reconnect: already reconnecting", self.fingerprint ) return if self._reconnect_task: # pragma: nocover return logger.info("%s Started reconnecting", self.fingerprint) self._set_state(ConnectionState.RECONNECTING) self._reconnect_task = asyncio.create_task( self._connect(return_exceptions=return_exceptions) )
[docs] def protocol_factory( self, connected_fut: asyncio.Future, loop: asyncio.AbstractEventLoop, cls: Type[protocol.Protocol] = protocol.Protocol, ): return cls( host=self._host, port=self._port, username=self._username, password=self._password, fetch_schema=self._fetch_schema, auto_refetch_schema=self._auto_refetch_schema, request_timeout=self._request_timeout, initial_read_buffer_size=self._initial_read_buffer_size, encoding=self._encoding, connected_fut=connected_fut, on_connection_made=None, on_connection_lost=self.connection_lost, loop=loop, )
[docs] async def _connect(self, return_exceptions: bool = True): loop = asyncio.get_running_loop() async with self._connect_lock: while True: try: ignore_states = { ConnectionState.CONNECTED, ConnectionState.DISCONNECTING, # disconnect() called } if self._state in ignore_states: self._reconnect_task = None return self._set_state(ConnectionState.CONNECTING) async def full_connect(): while True: connected_fut = loop.create_future() if self._host.startswith("unix/"): unix_path = self._port assert isinstance(unix_path, str), ( "port must be a str instance for " "unix socket" ) assert unix_path, "No unix file path specified" assert os.path.exists( unix_path ), "Unix socket `{}` not found".format(unix_path) conn = loop.create_unix_connection( functools.partial( self.protocol_factory, connected_fut, loop ), unix_path, ) else: conn = loop.create_connection( functools.partial( self.protocol_factory, connected_fut, loop ), self._host, self._port, ) tr, pr = await conn try: timeout = 0.05 # wait at least something if self._connect_timeout is not None: timeout = self._connect_timeout / 2 await asyncio.wait_for(connected_fut, timeout=timeout) except asyncio.TimeoutError: # pragma: nocover tr.close() continue # try again except Exception: tr.close() raise return tr, pr tr, pr = await asyncio.wait_for( full_connect(), timeout=self._connect_timeout, ) logger.info("%s Connected successfully", self.fingerprint) self._set_state(ConnectionState.CONNECTED) self._transport = tr self._protocol = pr self._set_db(self._protocol.get_common_db()) self._reconnect_task = None self._normalize_api() if self._ping_timeout: self._ping_task = loop.create_task(self._ping_task_func()) return except TarantoolDatabaseError as e: skip_errors = { ErrorCode.ER_LOADING, ErrorCode.ER_NO_SUCH_SPACE, ErrorCode.ER_NO_SUCH_INDEX_ID, } if e.code in skip_errors: # If Tarantool is still loading then reconnect if self._reconnect_timeout > 0: await self._wait_reconnect(e) continue if return_exceptions: self._reconnect_task = None raise e logger.exception(e) if self._reconnect_timeout > 0: await self._wait_reconnect(e) continue return # no reconnect, no return_exceptions except asyncio.CancelledError: logger.debug("connect is cancelled") self._reconnect_task = None raise except Exception as e: if self._reconnect_timeout > 0: await self._wait_reconnect(e) continue if return_exceptions: self._reconnect_task = None raise e logger.exception(e) return # no reconnect, no return_exceptions
[docs] async def _wait_reconnect(self, exc: Optional[Exception] = None): self._set_state(ConnectionState.RECONNECTING) logger.warning( "Connect to %s failed: %s. Retrying in %f seconds", self.fingerprint, repr(exc) if exc else "", self._reconnect_timeout, ) await asyncio.sleep(self._reconnect_timeout)
[docs] async def connect(self) -> "Connection": """ Connect coroutine """ await self._connect(True) return self
[docs] async def disconnect(self): """ Disconnect coroutine """ loop = asyncio.get_running_loop() async with self._disconnect_lock: if self._state == ConnectionState.DISCONNECTED: return self._set_state(ConnectionState.DISCONNECTING) logger.info("%s Disconnecting...", self.fingerprint) if self._reconnect_task: self._reconnect_task.cancel() self._reconnect_task = None if self._ping_task and not self._ping_task.done(): self._ping_task.cancel() self._ping_task = None self._clear_db() if self._transport: self._disconnect_waiter = loop.create_future() self._transport.close() self._transport = None self._protocol = None await self._disconnect_waiter self._disconnect_waiter = None self._set_state(ConnectionState.DISCONNECTED) else: self._transport = None self._protocol = None self._disconnect_waiter = None self._set_state(ConnectionState.DISCONNECTED)
[docs] def close(self): """ Same as disconnect, but not a coroutine, i.e. it does not wait for disconnect to finish. """ if self._state == ConnectionState.DISCONNECTED: return self._set_state(ConnectionState.DISCONNECTING) logger.info("%s Disconnecting...", self.fingerprint) if self._reconnect_task and not self._reconnect_task.done(): self._reconnect_task.cancel() self._reconnect_task = None if self._ping_task and not self._ping_task.done(): self._ping_task.cancel() self._ping_task = None if self._transport: self._transport.close() self._transport = None self._protocol = None self._disconnect_waiter = None self._clear_db() self._set_state(ConnectionState.DISCONNECTED)
[docs] async def reconnect(self): """ Reconnect coroutine. Just calls disconnect() and connect() """ await self.disconnect() await self.connect()
[docs] async def __aenter__(self) -> "Connection": """ Executed on entering the async with section. Connects to Tarantool instance. """ await self.connect() return self
[docs] async def __aexit__(self, exc_type, exc_val, exc_tb): """ Executed on leaving the async with section. Disconnects from Tarantool instance. """ await self.disconnect()
@property
[docs] def fingerprint(self) -> str: return "Tarantool[{}:{}]".format(self._host, self._port)
@property
[docs] def host(self) -> str: """ Tarantool host """ return self._host
@property
[docs] def port(self) -> int: """ Tarantool port """ return self._port
@property
[docs] def username(self) -> Optional[str]: """ Tarantool username """ return self._username
@property
[docs] def password(self) -> Optional[str]: """ Tarantool password """ return self._password
@property
[docs] def fetch_schema(self) -> bool: """ fetch_schema flag """ return self._fetch_schema
@property
[docs] def auto_refetch_schema(self) -> bool: """ auto_refetch_schema flag """ return self._auto_refetch_schema
@property
[docs] def encoding(self) -> str: """ Connection encoding """ return self._encoding
@property
[docs] def reconnect_timeout(self) -> float: """ Reconnect timeout value """ return self._reconnect_timeout
@property
[docs] def connect_timeout(self) -> float: """ Connect timeout value """ return self._connect_timeout
@property
[docs] def request_timeout(self) -> float: """ Request timeout value """ return self._request_timeout
@property
[docs] def version(self) -> Optional[tuple]: """ Protocol version tuple. ex.: (1, 6, 7) """ if self._protocol is None: return None return self._protocol.get_version()
@property
[docs] def state(self) -> ConnectionState: """ Current connection state :rtype: ConnectionState """ return self._state
@property
[docs] def is_connected(self) -> bool: """ Check if an underlying connection is active """ if self._protocol is None: return False return self._protocol.is_connected()
@property
[docs] def is_fully_connected(self) -> bool: """ Check if connection is fully active (performed auth and schema fetching) """ if self._protocol is None: return False return self._protocol.is_fully_connected()
@property
[docs] def schema_id(self) -> Optional[int]: """ Tarantool's current schema id """ if self._protocol is None: return None return self._protocol.schema_id
@property
[docs] def schema(self) -> Optional[protocol.Schema]: """ Current Tarantool schema with all spaces, indexes and fields """ if self._protocol is None: # pragma: nocover return None return self._protocol.schema
@property
[docs] def initial_read_buffer_size(self) -> int: """ initial_read_buffer_size value """ return self._initial_read_buffer_size
[docs] async def refetch_schema(self): """ Coroutine to force refetch schema """ await self._protocol.refetch_schema()
[docs] def _normalize_api(self): if (1, 6) <= self.version < (1, 7): # pragma: nocover Api.call = Api.call16 Connection.call = Connection.call16 if not self.features.streams: # pragma: nocover def stream_stub(_): raise TarantoolError("streams are available only in Tarantool 2.10+") Connection.stream = stream_stub
[docs] def __repr__(self): return "<asynctnt.Connection host={} port={} state={}>".format( self.host, self.port, repr(self.state) )
[docs] def stream(self) -> Stream: """ Create new stream suitable for interactive transactions """ stream = Stream() db = self._protocol.create_db(True) stream._set_db(db) return stream
@property
[docs] def features(self) -> protocol.IProtoFeatures: """ Lookup available Tarantool features - https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_iproto/feature/ :return: """ return self._protocol.features
[docs] async def connect(**kwargs) -> Connection: """ connect shorthand. See :class:`asynctnt.Connection` for kwargs details :return: :class:`asynctnt.Connection` object """ return await Connection(**kwargs).connect()