Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion burr/common/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

import inspect
from typing import AsyncGenerator, AsyncIterable, Generator, List, TypeVar, Union
from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Generator, List, TypeVar, Union

T = TypeVar("T")

Expand All @@ -27,6 +27,46 @@
SyncOrAsyncGeneratorOrItemOrList = Union[SyncOrAsyncGenerator[GenType], List[GenType], GenType]


class _AsyncPersisterContextManager:
"""Wraps an async coroutine that returns a persister so it can be used
directly with ``async with``::

async with AsyncSQLitePersister.from_values(...) as persister:
...

The wrapper awaits the coroutine on ``__aenter__`` and delegates
``__aexit__`` to the persister's own ``__aexit__``.

.. note::
Each instance wraps a single coroutine and can only be consumed once,
either via ``await`` or ``async with``. A second use will raise
``RuntimeError``.
"""

def __init__(self, coro: Coroutine[Any, Any, Any]):
self._coro = coro
self._persister = None
self._consumed = False

def __await__(self):
if self._consumed:
raise RuntimeError("This factory result has already been consumed")
self._consumed = True
return self._coro.__await__()

async def __aenter__(self):
if self._consumed:
raise RuntimeError("This factory result has already been consumed")
self._consumed = True
self._persister = await self._coro
return await self._persister.__aenter__()

async def __aexit__(self, exc_type, exc_value, traceback):
if self._persister is None:
return False
return await self._persister.__aexit__(exc_type, exc_value, traceback)


async def asyncify_generator(
generator: SyncOrAsyncGenerator[GenType],
) -> AsyncGenerator[GenType, None]:
Expand Down
35 changes: 27 additions & 8 deletions burr/integrations/persisters/b_aiosqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import aiosqlite

from burr.common.async_utils import _AsyncPersisterContextManager
from burr.common.types import BaseCopyable
from burr.core import State
from burr.core.persistence import AsyncBaseStatePersister, PersistedStateData
Expand Down Expand Up @@ -60,38 +61,56 @@ def copy(self) -> "Self":
PARTITION_KEY_DEFAULT = ""

@classmethod
async def from_config(cls, config: dict) -> "AsyncSQLitePersister":
def from_config(cls, config: dict) -> "_AsyncPersisterContextManager":
"""Creates a new instance of the AsyncSQLitePersister from a configuration dictionary.

Can be used with ``await`` or as an async context manager::

persister = await AsyncSQLitePersister.from_config(config)
# or
async with AsyncSQLitePersister.from_config(config) as persister:
...

The config key:value pair needed are:
db_path: str,
table_name: str,
serde_kwargs: dict,
connect_kwargs: dict,
"""
return await cls.from_values(**config)
return cls.from_values(**config)

@classmethod
async def from_values(
def from_values(
cls,
db_path: str,
table_name: str = "burr_state",
serde_kwargs: dict = None,
connect_kwargs: dict = None,
) -> "AsyncSQLitePersister":
) -> "_AsyncPersisterContextManager":
"""Creates a new instance of the AsyncSQLitePersister from passed in values.

Can be used with ``await`` or as an async context manager::

persister = await AsyncSQLitePersister.from_values(db_path="test.db")
# or
async with AsyncSQLitePersister.from_values(db_path="test.db") as persister:
...

:param db_path: the path the DB will be stored.
:param table_name: the table name to store things under.
:param serde_kwargs: kwargs for state serialization/deserialization.
:param connect_kwargs: kwargs to pass to the aiosqlite.connect method.
:return: async sqlite persister instance with an open connection. You are responsible
for closing the connection yourself.
"""
connection = await aiosqlite.connect(
db_path, **connect_kwargs if connect_kwargs is not None else {}
)
return cls(connection, table_name, serde_kwargs)

async def _create():
connection = await aiosqlite.connect(
db_path, **connect_kwargs if connect_kwargs is not None else {}
)
return cls(connection, table_name, serde_kwargs)

return _AsyncPersisterContextManager(_create())

def __init__(
self,
Expand Down
61 changes: 40 additions & 21 deletions burr/integrations/persisters/b_asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from burr.common.types import BaseCopyable
from burr.core import persistence, state
from burr.integrations import base
from burr.common.async_utils import _AsyncPersisterContextManager

try:
import asyncpg
Expand Down Expand Up @@ -106,12 +107,20 @@ async def create_pool(
return cls._pool

@classmethod
async def from_config(cls, config: dict) -> "AsyncPostgreSQLPersister":
"""Creates a new instance of the PostgreSQLPersister from a configuration dictionary."""
return await cls.from_values(**config)
def from_config(cls, config: dict) -> "_AsyncPersisterContextManager":
"""Creates a new instance of the PostgreSQLPersister from a configuration dictionary.

Can be used with ``await`` or as an async context manager::

persister = await AsyncPostgreSQLPersister.from_config(config)
# or
async with AsyncPostgreSQLPersister.from_config(config) as persister:
...
"""
return cls.from_values(**config)

@classmethod
async def from_values(
def from_values(
cls,
db_name: str,
user: str,
Expand All @@ -121,9 +130,16 @@ async def from_values(
table_name: str = "burr_state",
use_pool: bool = False,
**pool_kwargs,
) -> "AsyncPostgreSQLPersister":
) -> "_AsyncPersisterContextManager":
"""Builds a new instance of the PostgreSQLPersister from the provided values.

Can be used with ``await`` or as an async context manager::

persister = await AsyncPostgreSQLPersister.from_values(...)
# or
async with AsyncPostgreSQLPersister.from_values(...) as persister:
...

:param db_name: the name of the PostgreSQL database.
:param user: the username to connect to the PostgreSQL database.
:param password: the password to connect to the PostgreSQL database.
Expand All @@ -133,22 +149,25 @@ async def from_values(
:param use_pool: whether to use a connection pool (True) or a direct connection (False)
:param pool_kwargs: additional kwargs to pass to the pool creation
"""
if use_pool:
pool = await cls.create_pool(
user=user,
password=password,
database=db_name,
host=host,
port=port,
**pool_kwargs,
)
return cls(connection=None, pool=pool, table_name=table_name)
else:
# Original behavior - direct connection
connection = await asyncpg.connect(
user=user, password=password, database=db_name, host=host, port=port
)
return cls(connection=connection, table_name=table_name)

async def _create():
if use_pool:
pool = await cls.create_pool(
user=user,
password=password,
database=db_name,
host=host,
port=port,
**pool_kwargs,
)
return cls(connection=None, pool=pool, table_name=table_name)
else:
connection = await asyncpg.connect(
user=user, password=password, database=db_name, host=host, port=port
)
return cls(connection=connection, table_name=table_name)

return _AsyncPersisterContextManager(_create())

def __init__(
self,
Expand Down
8 changes: 4 additions & 4 deletions docs/concepts/parallelism.rst
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ When using state persistence with async parallelism, make sure to use the async
from burr.integrations.persisters.b_asyncpg import AsyncPGPersister

# Create an async persister with a connection pool
persister = AsyncPGPersister.from_values(
persister = await AsyncPGPersister.from_values(
host="localhost",
port=5432,
user="postgres",
Expand All @@ -707,7 +707,7 @@ When using state persistence with async parallelism, make sure to use the async
use_pool=True # Important for parallelism!
)

app = (
app = await (
ApplicationBuilder()
.with_state_persister(persister)
.with_action(
Expand All @@ -722,12 +722,12 @@ Remember to properly clean up your async persisters when you're done with them:

.. code-block:: python

# Using as a context manager
# Using as a context manager (recommended)
async with AsyncPGPersister.from_values(..., use_pool=True) as persister:
# Use persister here

# Or manual cleanup
persister = AsyncPGPersister.from_values(..., use_pool=True)
persister = await AsyncPGPersister.from_values(..., use_pool=True)
try:
# Use persister here
finally:
Expand Down
15 changes: 2 additions & 13 deletions tests/core/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,6 @@ def test_persister_methods_none_partition_key(persistence, method_name: str, kwa
"""Asyncio integration for sqlite persister + """


class AsyncSQLiteContextManager:
def __init__(self, sqlite_object):
self.client = sqlite_object

async def __aenter__(self):
return self.client

async def __aexit__(self, exc_type, exc, tb):
await self.client.close()


@pytest.fixture()
Expand Down Expand Up @@ -276,11 +267,9 @@ async def test_AsyncSQLitePersister_connection_shutdown():

@pytest.fixture()
async def initializing_async_persistence():
sqlite_persister = await AsyncSQLitePersister.from_values(
async with AsyncSQLitePersister.from_values(
db_path=":memory:", table_name="test_table"
)
async_context_manager = AsyncSQLiteContextManager(sqlite_persister)
async with async_context_manager as client:
) as client:
yield client


Expand Down
67 changes: 48 additions & 19 deletions tests/integrations/persisters/test_b_aiosqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,6 @@
from burr.integrations.persisters.b_aiosqlite import AsyncSQLitePersister


class AsyncSQLiteContextManager:
def __init__(self, sqlite_object):
self.client = sqlite_object

async def __aenter__(self):
return self.client

async def __aexit__(self, exc_type, exc, tb):
await self.client.cleanup()


async def test_copy_persister(async_persistence: AsyncSQLitePersister):
copy = async_persistence.copy()
assert copy.table_name == async_persistence.table_name
Expand All @@ -45,11 +34,9 @@ async def test_copy_persister(async_persistence: AsyncSQLitePersister):

@pytest.fixture()
async def async_persistence(request):
sqlite_persister = await AsyncSQLitePersister.from_values(
async with AsyncSQLitePersister.from_values(
db_path=":memory:", table_name="test_table"
)
async_context_manager = AsyncSQLiteContextManager(sqlite_persister)
async with async_context_manager as client:
) as client:
yield client


Expand Down Expand Up @@ -118,6 +105,50 @@ async def test_async_persister_methods_none_partition_key(
# these operations are stateful (i.e., read/write to a db)


async def test_async_sqlite_from_values_as_context_manager(tmp_path):
"""Test that from_values works directly with async with (issue #546)."""
db_path = str(tmp_path / "test.db")
async with AsyncSQLitePersister.from_values(db_path=db_path) as persister:
await persister.initialize()
await persister.save("pk", "app1", 1, "pos", State({"k": "v"}), "completed")
loaded = await persister.load("pk", "app1")
assert loaded is not None
assert loaded["state"] == State({"k": "v"})


async def test_async_sqlite_from_config_as_context_manager(tmp_path):
"""Test that from_config works directly with async with (issue #546)."""
db_path = str(tmp_path / "test.db")
config = {"db_path": db_path, "table_name": "burr_state"}
async with AsyncSQLitePersister.from_config(config) as persister:
await persister.initialize()
await persister.save("pk", "app1", 1, "pos", State({"k": "v"}), "completed")
loaded = await persister.load("pk", "app1")
assert loaded is not None


async def test_async_sqlite_from_values_cannot_be_consumed_twice():
"""Test that the factory wrapper raises on double consumption."""
wrapper = AsyncSQLitePersister.from_values(db_path=":memory:")
persister = await wrapper
with pytest.raises(RuntimeError, match="already been consumed"):
await wrapper
await persister.cleanup()


async def test_async_sqlite_context_manager_aexit_safe_on_failed_aenter(tmp_path):
"""Test that __aexit__ doesn't crash if __aenter__ never completed."""
from burr.common.async_utils import _AsyncPersisterContextManager

async def _failing_create():
raise ConnectionError("simulated connection failure")

mgr = _AsyncPersisterContextManager(_failing_create())
with pytest.raises(ConnectionError, match="simulated connection failure"):
async with mgr:
pass # should never reach here


async def test_AsyncSQLitePersister_from_values():
await asyncio.sleep(0.00001)
connection = await aiosqlite.connect(":memory:")
Expand Down Expand Up @@ -145,11 +176,9 @@ async def test_AsyncSQLitePersister_connection_shutdown():

@pytest.fixture()
async def initializing_async_persistence():
sqlite_persister = await AsyncSQLitePersister.from_values(
async with AsyncSQLitePersister.from_values(
db_path=":memory:", table_name="test_table"
)
async_context_manager = AsyncSQLiteContextManager(sqlite_persister)
async with async_context_manager as client:
) as client:
yield client


Expand Down