True sqlalchemy unit testing

tl;dr: Let’s mock the sqlalchemy calls and look at the generated SQL statements.

I’ve read a couple of articles about testing when sqlalchemy is involved. None of them really did unit testing: they all ended up doing local integration tests by spinning up a DB of some sort. While that can work pretty well in many use cases (in-memory sqlite!), it can also be a PITA when you need to start a docker container, wait for it to boot, install plugins (moddatetime, pg_trgm, unaccent, …), instantiate the models, insert the data, and finally run the tests. These integration tests have their value, but with such a slow setup there are benefits to keeping them to the bare minimum instead of making them the default way of testing.

Let’s start. Here is some code taken from the sqlalchemy documentation. First, a basic User model.

# models.py
from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(30))
    fullname: Mapped[str | None]

Next, here is some code that uses the model, also taken from the documentation.

# main.py
# as taken from the documentation
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

from models import User

engine = create_engine("sqlite://", echo=True)
session = Session(engine)

stmt = select(User).where(User.name.in_(["spongebob", "sandy"]))

for user in session.scalars(stmt):
    print(user)

Let’s refactor it so that it is easier to demo. The refactored version will:

  1. Have a get_session function that can be mocked.
  2. Move the “get users” functionality into a dedicated function.
  3. Add an extra “print users” function, to demo another layer of testing.
  4. Have a proper main block, so that it can also be run. (It will not be covered by the tests.)

# src/main.py
from sqlalchemy import URL, create_engine, select
from sqlalchemy.orm import Session

from src.models import User

# Module-level session: set by create_session(), read by get_session().
_session: Session | None = None


def create_session(url: URL) -> None:
    global _session
    engine = create_engine(url, echo=True)
    _session = Session(engine)


def get_session() -> Session:
    if _session is None:
        raise ValueError("Session not created")
    return _session


def get_users(names: list[str]) -> list[User]:
    stmt = select(User).where(User.name.in_(names))

    session = get_session()
    # session.scalars() returns a ScalarResult; materialize it into a real list.
    return list(session.scalars(stmt))


def print_users(names: list[str]) -> None:
    for user in get_users(names):
        print(f"User: {user.id} - {user.name}")


if __name__ == "__main__":
    create_session(
        URL.create(
            "postgresql+psycopg",
            username="root",
            password="password",
            host="127.0.0.1",
            port=5432,
            database="postgres",
        )
    )
    print_users(["spongebob", "sandy"])

Now, on to the test file, using pytest.

from textwrap import dedent
from unittest.mock import MagicMock, patch

from sqlalchemy.dialects import postgresql

from src import main, models


@patch("src.main.get_session", autospec=True)
def test_get_users(mock_get_session: MagicMock) -> None:
    mock_get_session().scalars.return_value = [
        models.User(id=1, name="spongebob", fullname="Spongebob Squarepants"),
        models.User(id=2, name="sandy", fullname="Sandy Cheeks"),
    ]

    users = main.get_users(["spongebob", "sandy"])

    stmt = mock_get_session().scalars.call_args[0][0]
    compiled_stmt = stmt.compile(
        dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}
    )
    query = str(compiled_stmt).replace(" \n", "\n")
    assert query == dedent(
        """\
        SELECT user_account.id, user_account.name, user_account.fullname
        FROM user_account
        WHERE user_account.name IN ('spongebob', 'sandy')"""
    )
    assert len(users) == 2
    assert users[0].name == "spongebob"
    assert users[1].name == "sandy"

Let’s dig into the test function line by line.

On line 9, the get_session function is patched with a MagicMock, which is passed as the first argument to the test function. (For more details, see unittest.mock.patch.) I also always use autospec=True, so that calls made against the mock are checked against the real function’s signature.
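To make the autospec point concrete, the sketch below uses create_autospec, which is what patch(..., autospec=True) relies on to build its replacement, on a hypothetical get_users stand-in, and contrasts it with a plain MagicMock.

```python
from unittest.mock import MagicMock, create_autospec


def get_users(names: list[str]) -> list[str]:
    """Hypothetical stand-in for main.get_users; its body never runs here."""
    raise NotImplementedError


# patch(..., autospec=True) builds its replacement with create_autospec,
# so calling create_autospec directly shows the same signature checks.
strict_mock = create_autospec(get_users, return_value=[])
strict_mock(["spongebob"])  # matches the real signature, so it is accepted

rejected = False
try:
    strict_mock(["spongebob"], "unexpected")  # one positional argument too many
except TypeError as exc:
    rejected = True
    print(f"autospec rejected the call: {exc}")

# A bare MagicMock accepts any call, so a signature drift would go unnoticed.
loose_mock = MagicMock(return_value=[])
loose_mock(["spongebob"], "unexpected")
```

If main.get_users ever gained or lost a parameter, an autospecced mock would make the outdated test fail instead of silently passing.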

In main.py, the get_users function uses the session in the following statement: session.scalars(stmt). We want this call to return what the actual database call would return, and that is what line 11 of the test file sets up. (The current demo doesn’t emphasize this, but if get_users did some post-processing on the query output, this would be even more valuable.)
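Setting mock_get_session().scalars.return_value works because calling a MagicMock always returns the same child mock, its return_value. A minimal standalone illustration, with plain strings standing in for User instances:

```python
from unittest.mock import MagicMock

mock_get_session = MagicMock(name="get_session")

# Every call to a MagicMock returns the same child mock (its return_value),
# so get_session() in the code under test hands back this same fake Session.
assert mock_get_session() is mock_get_session.return_value

# Configuring through return_value sets up scalars() without recording an
# extra call on mock_get_session.
mock_get_session.return_value.scalars.return_value = ["row 1", "row 2"]

session = mock_get_session()  # what get_users does internally
print(session.scalars("any statement"))  # ['row 1', 'row 2']
print(mock_get_session.call_count)  # 2: the assert above and the line above
```

Both spellings configure the same object; the return_value form only matters if the test later asserts on how many times get_session was called.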

On line 16, the function we’re testing is called.

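On line 18, the interesting bit begins. We retrieve the actual statement that was passed to the session.scalars call in main.py: call_args[0] holds the positional arguments of that call, and the second [0] picks the first one. At this point it is still a statement object rather than a string; its Python type is sqlalchemy.sql.selectable.Select.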
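The call_args[0][0] indexing reads a little cryptically. In this standalone sketch (the session mock and the statement placeholder are hypothetical), the same value is shown to be reachable through the named .args attribute, available since Python 3.8:

```python
from unittest.mock import MagicMock

session = MagicMock(name="session")
statement = object()  # hypothetical stand-in for the Select object

# The code under test would do the equivalent of session.scalars(stmt).
session.scalars(statement)

# call_args behaves like an (args, kwargs) pair...
args, kwargs = session.scalars.call_args
assert args[0] is statement and kwargs == {}

# ...so these two spellings point at the same object.
assert session.scalars.call_args[0][0] is statement  # form used in the test
assert session.scalars.call_args.args[0] is statement  # named form
```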

On lines 19 to 21, the compile method is invoked on the stmt object, bringing it closer to an actual query string. The target dialect is PostgreSQL, and compile_kwargs={"literal_binds": True} makes the query string contain the actual values instead of placeholders. (The compiled_stmt object is a sqlalchemy.dialects.postgresql.base.PGCompiler, which is PostgreSQL specific, since that is the dialect we targeted.)

On line 22, the query string is retrieved by calling str on compiled_stmt (which invokes its __str__ method). SQLAlchemy puts a space before each line break between clauses, so replace(" \n", "\n") strips those trailing spaces. This makes the expected string easier to write and keeps linters from complaining about trailing whitespace.
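To see exactly what that replace call removes, this sketch prints the raw compiled string for a smaller, hypothetical two-column user_account table built with SQLAlchemy’s lightweight table() and column() constructs, standing in for the User model:

```python
from sqlalchemy import Integer, String, column, select, table
from sqlalchemy.dialects import postgresql

# Hypothetical, trimmed-down stand-in for the user_account table.
user_account = table("user_account", column("id", Integer), column("name", String))
stmt = select(user_account).where(user_account.c.name == "sandy")

raw = str(
    stmt.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})
)
print(repr(raw))  # note the " \n" before FROM and WHERE

# Dropping the space before each newline gives clean, linter-friendly lines.
normalized = raw.replace(" \n", "\n")
print(normalized.splitlines())
```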

On line 23, the assertion is made on the query. 🎉 textwrap.dedent removes the indentation from the triple-quoted string, so the expected SQL can stay aligned with the test code.
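The dedent pattern on its own: the backslash right after the opening quotes drops the leading newline, and dedent then strips the whitespace prefix shared by every line.

```python
from textwrap import dedent

# Indented like test code, but the resulting string has no leading
# newline and no indentation.
expected = dedent(
    """\
    SELECT user_account.id, user_account.name
    FROM user_account
    WHERE user_account.name = 'sandy'"""
)
print(expected.splitlines())
# ['SELECT user_account.id, user_account.name', 'FROM user_account', "WHERE user_account.name = 'sandy'"]
```

Closing the string right after the last line (rather than on its own line) also avoids a trailing newline, which the compiled SQL does not have.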

The rest are regular assertions.

If this function could generate a broader range of statements (for example, depending on its arguments), more tests could be written using the same technique.

Next, let’s look at our call stack. If we ran this normally (i.e. not under test), we’d get:

  1. __main__
  2. print_users
  3. get_users

You probably don’t want to mock the internals of the get_users function every time a function depends on it. Instead, you’ll want to mock it entirely.

from pytest import CaptureFixture


@patch(
    "src.main.get_users",
    autospec=True,
    return_value=[
        models.User(id=1, name="spongebob", fullname="Spongebob Squarepants"),
        models.User(id=2, name="sandy", fullname="Sandy Cheeks"),
    ],
)
def test_print_users(mock_get_users: MagicMock, capsys: CaptureFixture[str]) -> None:
    main.print_users(["spongebob", "sandy"])
    out, err = capsys.readouterr()
    out_lines = out.split("\n")
    assert out_lines[0] == "User: 1 - spongebob"
    assert out_lines[1] == "User: 2 - sandy"
    assert err == ""
    mock_get_users.assert_called_once_with(["spongebob", "sandy"])

On line 4, the get_users function is mocked entirely.

On line 12, the capsys fixture is used. In case you haven’t come across it: it captures stdout and stderr so that they can be retrieved and asserted on. This happens on lines 14 to 18.
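capsys only exists inside pytest, so to keep this sketch runnable on its own, contextlib.redirect_stdout stands in for it, and print_users is a hypothetical stand-in with get_users already replaced. It highlights one detail of the captured output: print ends each line with "\n", so out.split("\n") leaves an empty string at the end, whereas splitlines() does not. Comparing the whole splitlines() result would also catch an unexpected extra line, which checking only out_lines[0] and out_lines[1] would not.

```python
import io
from contextlib import redirect_stdout


def print_users(names: list[str]) -> None:
    """Hypothetical stand-in for main.print_users with get_users mocked away."""
    for user_id, name in enumerate(names, start=1):
        print(f"User: {user_id} - {name}")


buffer = io.StringIO()
with redirect_stdout(buffer):  # plays the role capsys plays under pytest
    print_users(["spongebob", "sandy"])
out = buffer.getvalue()

# print() terminates every line with "\n", so split("\n") leaves an empty tail.
print(out.split("\n"))  # ['User: 1 - spongebob', 'User: 2 - sandy', '']

# splitlines() does not, so the whole output can be compared in one assertion.
assert out.splitlines() == ["User: 1 - spongebob", "User: 2 - sandy"]
```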

On line 19, it asserts that the get_users function was called exactly once and with the expected parameters.
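Both halves of assert_called_once_with matter: the call count and the arguments. A small standalone sketch with a plain MagicMock standing in for the patched get_users:

```python
from unittest.mock import MagicMock

mock_get_users = MagicMock(return_value=[])

mock_get_users(["spongebob", "sandy"])
mock_get_users.assert_called_once_with(["spongebob", "sandy"])  # passes

failures = []

# A second call, even with identical arguments, breaks the "once" part.
mock_get_users(["spongebob", "sandy"])
try:
    mock_get_users.assert_called_once_with(["spongebob", "sandy"])
except AssertionError as exc:
    failures.append(str(exc))

# After a reset, a single call with other arguments breaks the "with" part.
mock_get_users.reset_mock()
mock_get_users(["patrick"])
try:
    mock_get_users.assert_called_once_with(["spongebob", "sandy"])
except AssertionError as exc:
    failures.append(str(exc))

print(len(failures))  # 2
```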

Afterthoughts

I’ve been using this technique for a while, and it definitely works.

But is it actually good?

To some extent, I feel like it’s testing sqlalchemy’s ability to generate proper SQL. On the other hand, it may be the best (or simplest) we can do for true unit testing, without even involving a local server. I did not try it, but a different technique, such as rebuilding the ORM statements in the tests and comparing against them, might also work. What I do like about testing the generated SQL string is that we can also directly see how non-natively mapped objects behave. If we just built the same ORM statement in the tests, we would fail to capture how these objects are translated into SQL. Arguably, this could also be tested separately.

Do you think this is a good technique? Do you have a better one?

If you want to try it out, the source code is available here.

Featured image by xxaries1970xx. Source: https://www.deviantart.com/xxaries1970xx/art/Digital-Circuit-936780185
