Pour tout problème contactez-nous par mail : support@froggit.fr | La FAQ :grey_question: | Rejoignez-nous sur le Chat :speech_balloon:

Skip to content
Snippets Groups Projects
Commit 558a173b authored by Dorian Turba's avatar Dorian Turba
Browse files

add an example test with async_fsm

parent 3e0cccdf
No related branches found
No related tags found
No related merge requests found
Pipeline #23878 failed
......@@ -26,12 +26,194 @@ class Base(sa.orm.DeclarativeBase):
def session_maker(cls) -> sqlalchemy.orm.sessionmaker:
return <namespace>.session_maker
class User(dal_poc.DML, Base): ...
"""
)
@classmethod
def async_session_maker(cls) -> sqlalchemy.ext.asyncio.async_sessionmaker:
raise NotImplementedError(
"""
Inherit from Base first,
and implement @classmethod async_session_maker() in the DeclarativeBase of your models:
class Base(sa.orm.DeclarativeBase):
@classmethod
def async_session_maker(cls) -> sqlalchemy.ext.asyncio.async_sessionmaker:
return <namespace>.async_session_maker
class User(dal_poc.DML, Base): ...
"""
)
class Select(SessionMakerMixin):
@classmethod
async def _async_select_result(
cls,
where: collections.abc.Sequence[sqlalchemy.ColumnExpressionArgument[bool]],
columns: collections.abc.Sequence[COLUMNS_CLAUSE_ARGUMENT],
limit: int,
) -> collections.abc.Sequence[sqlalchemy.Row]:
"""
Select rows of columns of a table.
A default limit of 500 is applied to prevent accidental large queries.
>>> User._select_result()
[(1, 'Jane'), (2, 'John')]
>>> User._select_result(columns=[User.name, User.id])
[('Jane', 1), ('John', 2)]
>>> User._select_result(where=[User.name == 'John'])
[(2, 'John')]
>>> User._select_result(limit=1)
[(1, 'Jane')]
>>> User._select_result(limit=-1)
[(1, 'Jane'), (2, 'John')]
>>> User._select_result(limit=1)[0].name
'Jane'
:param where: Mandatory where clause, explicit empty where clause is []. Default doesn't
filter.
:param columns: None for all columns, a specific column, or a non-empty sequence of columns.
Default is None.
:param limit: -1 for no limit, 0 for no rows, >0 for a limit. Default is 500.
:return: a sequence of rows of selected columns
"""
stmt = sqlalchemy.select(*columns).where(*where).limit(limit)
async with cls.async_session_maker() as session: # type: ignore
return (await session.execute(stmt)).all()
@classmethod
async def _async_select_scalar(
cls,
where: collections.abc.Sequence[sqlalchemy.ColumnExpressionArgument[bool]],
columns: COLUMNS_CLAUSE_ARGUMENT,
limit: int,
) -> collections.abc.Sequence[typing.Any]:
"""
Select a single column of a table.
The column's values are returned as sequence of values instead of sequence of rows.
>>> User._select_scalar(columns=User.name)
['Jane', 'John']
>>> User._select_scalar(columns=User.name, where=[User.name == 'John'])
['John']
>>> User._select_scalar(columns=User.name, limit=1)
['Jane']
>>> User._select_scalar(columns=User.name, limit=-1)
['Jane', 'John']
:param where: mandatory where clause, explicit empty where clause is []. Default doesn't
filter.
:param columns: None for all columns, a specific column, or a non-empty sequence of columns.
Default is None.
:param limit: -1 for no limit, 0 for no rows, >0 for a limit. Default is 500.
:return: a sequence of scalars of a single column
"""
stmt = sqlalchemy.select(columns).where(*where).limit(limit)
async with cls.async_session_maker() as session: # type: ignore
return (await session.scalars(stmt)).all()
@classmethod
@typing.overload
async def async_select(
cls,
where: collections.abc.Sequence[sqlalchemy.ColumnExpressionArgument[bool]],
columns: collections.abc.Sequence[COLUMNS_CLAUSE_ARGUMENT],
limit: int,
) -> collections.abc.Sequence[sqlalchemy.Row]:
...
@classmethod
@typing.overload
async def async_select(
cls,
where: collections.abc.Sequence[sqlalchemy.ColumnExpressionArgument[bool]],
columns: COLUMNS_CLAUSE_ARGUMENT,
limit: int,
) -> collections.abc.Sequence[typing.Any]:
...
@classmethod
async def async_select(
cls,
where: typing.Union[
collections.abc.Sequence[sqlalchemy.ColumnExpressionArgument[bool]], None
] = None,
columns: typing.Union[
collections.abc.Sequence[COLUMNS_CLAUSE_ARGUMENT], COLUMNS_CLAUSE_ARGUMENT, None
] = None,
limit: int = 500,
) -> typing.Union[
collections.abc.Sequence[sqlalchemy.Row],
collections.abc.Sequence[typing.Any],
]:
"""
Select rows of columns or scalars of a single column of a table.
A default limit of 500 is applied to prevent accidental large queries.
>>> await User.async_select()
[(1, 'Jane'), (2, 'John')]
>>> await User.async_select(columns=[User.name, User.id])
[('Jane', 1), ('John', 2)]
>>> await User.async_select(columns=[User.name])
[('Jane',), ('John',)]
>>> await User.async_select(where=[User.name == 'John'])
[(2, 'John')]
>>> await User.async_select(limit=1)
[(1, 'Jane')]
>>> await User.async_select(limit=-1)
[(1, 'Jane'), (2, 'John')]
>>> await User.async_select(columns=User.name)
['Jane', 'John']
>>> await User.async_select(columns=User.name, where=[User.name == 'John'])
['John']
>>> await User.async_select(columns=User.name, limit=1)
['Jane']
>>> await User.async_select(columns=User.name, limit=-1)
['Jane', 'John']
>>> await User.async_select(limit=1)[0].name
'Jane'
>>> await User.async_select(columns=[User.name])[0].name # Get a list of rows
'Jane'
>>> await User.async_select(columns=User.name)[0] # Get a list of names
'Jane'
:param where: mandatory where clause, explicit empty where clause is []. Default doesn't
filter.
:param columns: None for all columns, a specific column, or a non-empty sequence of columns.
Default is None.
:param limit: -1 for no limit, 0 for no rows, >0 for a limit. Default is 500.
:return: a sequence of rows of selected columns or a sequence of scalars of a single column
:raises ValueError: if columns is an empty sequence
:raises TypeError: if columns is not None, a specific column, or a non-empty sequence
"""
if where is None:
where = []
if isinstance(columns, collections.abc.Sequence) and len(columns) == 0:
raise ValueError("columns must be None, a specific column, or a non-empty sequence")
if isinstance(columns, collections.abc.Sequence):
return await cls._async_select_result(where, columns, limit)
if isinstance(columns, COLUMNS_CLAUSE_ARGUMENT):
return await cls._async_select_scalar(where, columns, limit)
if columns is None:
return await cls._async_select_result(
where,
cls.__table__.columns, # type: ignore
limit,
)
raise TypeError(
f"columns must be None, a specific column, or a non-empty sequence; got {columns!r}"
)
@classmethod
def _select_result(
cls,
......
......@@ -5,7 +5,7 @@ import sqlite3
import pydantic_settings
import pytest
import sqlalchemy.orm
from fake_session_maker import fsm
from fake_session_maker import async_fsm, fsm
import model_dml
......@@ -83,3 +83,16 @@ def fake_session_maker() -> sqlalchemy.orm.sessionmaker:
# the fake_session_maker won't auto-commit after transaction
# and rollback after transaction
yield fake_session_maker_
@pytest.fixture
def async_fake_session_maker() -> sqlalchemy.orm.sessionmaker:
with async_fsm(
db_url=settings.sqlalchemy_test_url,
namespace=Namespace,
symbol_name="session_maker",
create_engine_kwargs={"echo": True},
) as fake_session_maker_:
# the fake_session_maker won't auto-commit after transaction
# and rollback after transaction
yield fake_session_maker_
import pytest
import sqlalchemy.exc
pytest_plugins = ("pytest_asyncio",)
def test_select_all_columns(fake_session_maker, user):
with fake_session_maker.begin() as session:
......@@ -9,6 +11,14 @@ def test_select_all_columns(fake_session_maker, user):
assert user.select() == [(1, "John")]
@pytest.mark.asyncio
async def test_async_select_all_columns(async_fake_session_maker, user):
async with async_fake_session_maker.begin() as session:
session.add(user(name="John"))
assert (await user.select()) == [(1, "John")]
def test_select_one_column(fake_session_maker, user):
with fake_session_maker.begin() as session:
session.add(user(name="John"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment