はじめに

FastAPISQLAlchemy を利用して Web API 開発を行っていた際、SQLAlchemy のマイグレーションツールである alembic を利用していました。

ただ E2E テストを書こうとした際に、pytest 実行中に alembic でデータベースマイグレーションを行う方法が分からず模索していました。結果的にマイグレーションのやり方は分かったものの一応今後も利用するかもしれないため、その内容を記事として残しておくことにしました。

本記事内で利用しているソースコードを含む FastAPI プロジェクトを GitHub リポジトリ上にアップしておいたので、詳細を確認されたい方がいればご参照くださいませ。

alembic でマイグレーションを行う

conftest.py にグローバルで利用するマイグレーション用の fixture を定義すれば OK です。

# conftest.py

import os

import alembic.config
import pytest
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import database_exists, create_database, drop_database

# テスト用の初期データを定義した module を import する (必要があれば)
# from .seed import users, contents


# 指定したパラメータを用いて alembic によるデータベースマイグレーションを行う
# 引数のデフォルト設定では全てのマイグレーションを実行するようになっている
def migrate(migrations_path, alembic_ini_path='alembic.ini', connection=None, revision="head"):
    config = alembic.config.Config(alembic_ini_path)
    config.set_main_option('script_location', migrations_path)
    if connection is not None:
        config.attributes['connection'] = connection
    alembic.command.upgrade(config, revision)


# テスト実行用にセットアップされたデータベースのセッション情報を扱う関数
# scope に session を指定することでテスト全体で一回だけ実行されるようにする
@pytest.fixture(scope="session", autouse=True)
def SessionLocal():
    test_sqlalchemy_database_url = os.environ['DATABASE_URL']
    engine = create_engine(test_sqlalchemy_database_url)

    # 既にテスト用データベースが存在していたら破棄する
    if database_exists(test_sqlalchemy_database_url):
        drop_database(test_sqlalchemy_database_url)

    # テスト用データベースを作成する
    create_database(test_sqlalchemy_database_url)

    # 環境変数 DATABASE_URL で指定したデータベースに対して、
    # マイグレーションを行いテスト実行に必要なテーブルを一括作成する
    # 第一引数に指定している alembic は `alembic init <環境名>` 実行時に指定した環境名を入力
    with engine.begin() as connection:
        migrate("alembic", 'alembic.ini', connection)

    Base = declarative_base()
    Base.metadata.create_all(engine)
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

    # テスト用の各種データを追加する (必要があれば)
    # db_session = SessionLocal()

    # for user in users:
    #     db_session.add(user)
    # db_session.commit()

    # for content in contents:
    #     db_session.add(content)
    # db_session.commit()

    # db_session.close()

    # テスト用データ追加後のセットアップ済みの状態で
    # テスト用に利用する SessionLocal を返却する
    yield SessionLocal

    # テストが全て終わったら、テスト用データベースを破棄して、
    # SQLAlchemy のセッションも切断する
    drop_database(test_sqlalchemy_database_url)
    engine.dispose()

FastAPI の pytest への適用例

上記を関数を利用する方法は各自のテスト環境によって異なると思いますが、一応私が FastAPI のテストコードを書く際に利用したソースコードを元に参考例を載せておきます。

conftest.py と同じディレクトリに client.py を作成します。

# client.py

from fastapi import Header, HTTPException, status
from fastapi.testclient import TestClient

from app.dependencies import get_database
from app.main import app


# conftest で定義した fixture の SessionLocal を元に、
# データベースセッションを作成するための override_get_db 関数を定義して、
# get_database の代わりに override_get_db を実行するよう差し替える
def temp_db(f):
    def func(SessionLocal, *args, **kwargs):
        def override_get_db():
            db = SessionLocal()
            try:
                yield db
            finally:
                db.close()

        app.dependency_overrides[get_database] = override_get_db
        f(*args, **kwargs)
        app.dependency_overrides[get_database] = get_database

    return func

client = TestClient(app)

あとは pytest のコード内で下記のような記述を行えば、FastAPI の内部でテスト用データベースを利用してくれるようになります。

from fastapi import status

from .client import client, temp_db


# temp_db fixture を定義しておくことで、
# 関数の実行中は FastAPI の内部でテスト用データベースを利用する
@temp_db
def test_read_me_token_valid():
    response = client.get("/users/me", headers={"Authorization": "Bearer 1234567890"})
    assert response.status_code == status.HTTP_200_OK

#...

参考リンク