# Source code for shinto.pg.pool

"""Connection pool module, a wrapper around psycopg_pool ConnectionPool and AsyncConnectionPool."""

from __future__ import annotations

from contextlib import asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, Any

import psycopg_pool

from shinto.pg.connection import AsyncConnection, Connection

if TYPE_CHECKING:  # pragma: no cover
    from collections.abc import AsyncGenerator, Generator


class ConnectionPool(psycopg_pool.ConnectionPool):
    """
    ConnectionPool class.

    A ``psycopg_pool.ConnectionPool`` that hands out shinto ``Connection`` objects.
    Unless another ``connection_class`` is passed, the pool creates ``Connection``
    instances directly, so the object yielded by :meth:`connection` is the very
    connection the pool manages (with its autocommit setting, configured adapters
    and prepared-statement state intact).

    Example:
        >>> pool = shinto.pg.ConnectionPool(
        ...     min_size=1,
        ...     max_size=10,
        ...     kwargs={
        ...         "host": "localhost",
        ...         "port": 6432,
        ...         "dbname": "mydb",
        ...         "user": "myuser",
        ...         "password": "mypass",
        ...     },
        ... )
        >>> with pool.connection() as conn:
        ...     conn.execute_query("SELECT * FROM mytable")

    """

    def __init__(
        self,
        *args: Any,
        connection_class: type[Any] = Connection,
        **kwargs: Any,
    ) -> None:
        """
        Create the pool, defaulting ``connection_class`` to shinto's ``Connection``.

        Args:
            *args: Positional arguments forwarded to ``psycopg_pool.ConnectionPool``
                (e.g. ``conninfo``).
            connection_class (type): Class the pool uses to create connections.
                Defaults to ``Connection``.
            **kwargs: Keyword arguments forwarded to ``psycopg_pool.ConnectionPool``
                (e.g. ``min_size``, ``max_size``, ``kwargs``).

        """
        super().__init__(*args, connection_class=connection_class, **kwargs)

    @contextmanager
    def connection(self, timeout: float | None = None) -> Generator[Connection, None, None]:
        """
        Context manager to obtain a connection from the pool.

        Yields a custom Connection object that extends psycopg.Connection.

        Args:
            timeout (float | None): The maximum time to wait for a connection.

        Yields:
            Connection: A connection to the database.

        """
        with super().connection(timeout) as conn:
            if isinstance(conn, Connection):
                # The pool already manages shinto connections: hand out the pooled
                # object itself so its state stays consistent across checkouts.
                yield conn
            else:
                # A non-shinto ``connection_class`` was supplied explicitly; keep the
                # historical behaviour of wrapping the same libpq connection. Note
                # the wrapper does not share the pooled object's client-side state.
                yield Connection(conn.pgconn, conn.row_factory)
class AsyncConnectionPool(psycopg_pool.AsyncConnectionPool):
    """
    AsyncConnectionPool class.

    A ``psycopg_pool.AsyncConnectionPool`` that hands out shinto ``AsyncConnection``
    objects. Unless another ``connection_class`` is passed, the pool creates
    ``AsyncConnection`` instances directly, so the object yielded by
    :meth:`connection` is the very connection the pool manages (with its autocommit
    setting, configured adapters and prepared-statement state intact).

    Example:
        >>> pool = shinto.pg.AsyncConnectionPool(
        ...     min_size=1,
        ...     max_size=10,
        ...     kwargs={
        ...         "host": "localhost",
        ...         "port": 6432,
        ...         "dbname": "mydb",
        ...         "user": "myuser",
        ...         "password": "mypass",
        ...     },
        ... )
        >>> async with pool.connection() as conn:
        ...     await conn.execute_query("SELECT * FROM mytable")

    """

    def __init__(
        self,
        *args: Any,
        connection_class: type[Any] = AsyncConnection,
        **kwargs: Any,
    ) -> None:
        """
        Create the pool, defaulting ``connection_class`` to shinto's ``AsyncConnection``.

        Args:
            *args: Positional arguments forwarded to
                ``psycopg_pool.AsyncConnectionPool`` (e.g. ``conninfo``).
            connection_class (type): Class the pool uses to create connections.
                Defaults to ``AsyncConnection``.
            **kwargs: Keyword arguments forwarded to
                ``psycopg_pool.AsyncConnectionPool`` (e.g. ``min_size``, ``kwargs``).

        """
        super().__init__(*args, connection_class=connection_class, **kwargs)

    @asynccontextmanager
    async def connection(
        self,
        timeout: float | None = None,  # noqa: ASYNC109, RUF100: Use `asyncio.timeout` instead. Not applicable for overriding method.
    ) -> AsyncGenerator[AsyncConnection, None]:
        """
        Context manager to obtain an async connection from the pool.

        Yields a custom AsyncConnection object that extends psycopg.AsyncConnection.

        Args:
            timeout (float | None): The maximum time to wait for a connection.

        Yields:
            AsyncConnection: A connection to the database.

        """
        async with super().connection(timeout) as conn:
            if isinstance(conn, AsyncConnection):
                # The pool already manages shinto connections: hand out the pooled
                # object itself so its state stays consistent across checkouts.
                yield conn
            else:
                # A non-shinto ``connection_class`` was supplied explicitly; keep the
                # historical behaviour of wrapping the same libpq connection. Note
                # the wrapper does not share the pooled object's client-side state.
                yield AsyncConnection(conn.pgconn, conn.row_factory)