コンテンツへスキップ

SQLAlchemy 1.4の利用メモ

タグ:

https://www.sqlalchemy.org/

https://github.com/sqlalchemy/sqlalchemy/blob/main/lib/sqlalchemy/ext/horizontal_shard.py

水平分割(Horizontal shard)

SQLAlchemyには水平分割を実現するための拡張（sqlalchemy.ext.horizontal_shard）が同梱されており、Sessionのクラスとして ShardedSession を指定することで水平分割を実現できます。

SQLAlchemyの水平分割は、クエリの処理をフックすることで問い合わせ内容にあわせて適切なエンジンを割り当てる仕組みになっています。

水平分割機能には複数のengine(データベース)に問い合わせた結果を結合する処理が含まれるため、単一の結果を前提とする処理、例えばcount()のような集計は期待どおりに動作しません（シャードごとの結果がそれぞれ返ってくるため）。本記事のmain.pyでは、シャードごとに返る件数を合計して全体の件数を求めています。

db.py

  • shard_chooser
    insert時に呼び出され、対象のengineを決定します。
  • id_chooser
    get時に呼び出され、対象のengineを決定します。
  • execute_chooser
    クエリ実行時（selectやテキストSQL）に呼び出され、問い合わせ先のengine（複数可）を決定します。
import typing
from sqlalchemy import create_engine
from sqlalchemy.sql import operators
from sqlalchemy.sql import visitors
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.selectable import Select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.horizontal_shard import ShardedSession
import db_model

ENGINE_ECHO: bool = False
ENGINE_FUTURE: bool = True

# One SQLite database file per shard. The dict key is the shard id, i.e. the
# value shard_chooser returns (a row's group_id) and the key ShardedSession
# receives through ``shards=dict_shard``.
dict_shard = {
    shard_id: create_engine(url, echo=ENGINE_ECHO, future=ENGINE_FUTURE)
    for shard_id, url in (
        ("group_0", "sqlite:///database0.db"),
        ("group_1", "sqlite:///database1.db"),
        ("group_2", "sqlite:///database2.db"),
        ("group_3", "sqlite:///database3.db"),
    )
}


def shard_chooser(mapper, instance, clause=None) -> str:
    """Return the shard id for an object that is being persisted.

    Both mapped models (``User`` and ``UserParam``) carry a ``group_id``
    column, and its value is used directly as the shard id, i.e. as a key
    of ``dict_shard``.

    :param mapper: mapper of ``instance``; unused here.
    :param instance: the ORM object being saved; must expose ``group_id``.
    :param clause: unused here.
    :returns: ``instance.group_id`` -- a string such as ``"group_0"``
        (``group_id`` is a String column, so the shard id is not an int).
    """
    # Every model stores its shard key under the same attribute name, so no
    # per-class dispatch is needed.
    return instance.group_id


def id_chooser(query, ident):
    """Return the shard ids to search when loading an object by primary key.

    :param query: query object handed over by ShardedSession; only its
        ``lazy_loaded_from`` attribute is consulted.
    :param ident: the primary-key identity being looked up; unused, since a
        bare primary key carries no shard information in this schema.
    :returns: ``[identity_token]`` of the parent object during a lazy load,
        otherwise every key of ``dict_shard``.
    """
    parent_state = query.lazy_loaded_from
    if not parent_state:
        # Plain get(): nothing tells us the shard, so search all of them.
        return dict_shard.keys()
    # Lazy load: stay in the shard the parent object was loaded from.
    return [parent_state.identity_token]


def _get_select_comparisons(statement):
    """Collect simple ``column <op> bound-value`` comparisons from a SELECT.

    Returns a list of ``(column, operator, value)`` tuples, one per binary
    expression in the WHERE clause whose one side is a column and whose
    other side is a bind parameter; ``value`` is that parameter's
    ``effective_value`` (presumably the list of values for an ``in_()`` --
    confirm). Anything that is not a ``Select`` (textual SQL included) and
    any SELECT without a WHERE clause yields an empty list.
    """
    found = []
    if not isinstance(statement, Select):
        # TextClause and other non-SELECT statements give nothing to route on.
        return found
    where = statement.whereclause
    if where is None:
        return found

    seen_columns = set()
    bound_values = {}

    def on_column(column):
        seen_columns.add(column)

    def on_bindparam(bind: BindParameter):
        bound_values[bind] = bind.effective_value

    def on_binary(binary):
        lhs, rhs = binary.left, binary.right
        if lhs in seen_columns and rhs in bound_values:
            found.append((lhs, binary.operator, bound_values[rhs]))
        elif lhs in bound_values and rhs in seen_columns:
            found.append((rhs, binary.operator, bound_values[lhs]))

    # Three separate passes, in this order, so that every column and bind
    # parameter is recorded before any binary expression is inspected.
    # Presumably needed because traverse() visits a parent expression before
    # its children -- keep the order if this is ever refactored.
    visitors.traverse(where, {}, {"column": on_column})
    visitors.traverse(where, {}, {"bindparam": on_bindparam})
    visitors.traverse(where, {}, {"binary": on_binary})
    return found


def execute_chooser(context):
    """Return the shard ids a statement has to be executed on.

    The WHERE clause is scanned for ``User.group_id == value`` and
    ``User.group_id.in_(values)`` comparisons; the compared values are shard
    ids because shard_chooser stores each row in the shard named by its
    ``group_id``. The collected ids are de-duplicated (first-seen order is
    kept) and restricted to keys of ``dict_shard``. When no usable id is
    found -- textual SQL included -- every shard is returned.

    NOTE(review): comparisons are gathered without regard to AND/OR, so a
    filter such as ``or_(User.group_id == "group_1", User.name == "x")`` is
    routed to group_1 only and would miss matching rows on other shards.

    :param context: execution state supplied by ShardedSession; only its
        ``statement`` attribute is read.
    :returns: a list of shard ids, or ``dict_shard.keys()`` meaning "all".
    """
    requested = []
    for column, operator, value in _get_select_comparisons(context.statement):
        if not column.shares_lineage(db_model.User.group_id):
            continue
        if operator == operators.eq:
            requested.append(value)
        elif operator == operators.in_op:
            requested.extend(value)

    # A repeated id (e.g. in_(["group_1", "group_1"])) would presumably run
    # the statement on that shard more than once and duplicate its rows, and
    # an id that is not a configured shard has no engine to run on; drop both.
    list_shard = [
        shard_id
        for shard_id in dict.fromkeys(requested)
        if shard_id in dict_shard
    ]

    if not list_shard:
        return dict_shard.keys()
    return list_shard


# Factory for ShardedSession instances spanning every engine in dict_shard,
# wired to the three routing callbacks defined above.
Session = sessionmaker(
    class_=ShardedSession,
    future=True,
    shards=dict_shard,
    shard_chooser=shard_chooser,
    id_chooser=id_chooser,
    execute_chooser=execute_chooser,
)

# Issue create_all for the db_model schema against each shard database.
_metadata = db_model.Base.metadata
for _engine in dict_shard.values():
    _metadata.create_all(_engine)

db_model.py

from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import declarative_base

# Declarative base for the models in this module; db.py runs
# Base.metadata.create_all() against each shard engine.
Base = declarative_base()


class User(Base):
    """A user; its ``group_id`` is also the id of the shard it is stored in."""

    __tablename__ = "users"
    # NOTE(review): each shard has its own "users" table, so nothing here keeps
    # ids unique across shards -- main.py assigns them explicitly.
    id = Column(Integer, primary_key=True, nullable=False)
    # Shard key: db.shard_chooser returns this value as the shard id, so it
    # presumably must be a key of db.dict_shard (e.g. "group_0").
    group_id = Column(String, nullable=False)
    name = Column(String, nullable=False)


class UserParam(Base):
    """A per-user parameter row, routed to a shard by its own ``group_id``."""

    __tablename__ = "user_params"
    id = Column(Integer, primary_key=True, nullable=False)
    # Shard key, same meaning as User.group_id.
    # NOTE(review): this foreign key targets users.group_id, which is neither a
    # primary key nor declared unique; many databases presumably reject such a
    # constraint (SQLite may only object once FK enforcement is enabled).
    # Confirm the intended target -- perhaps users.id via a user_id column.
    group_id = Column(String, ForeignKey("users.group_id"), nullable=False)
    param = Column(String, nullable=False)

main.py

import uuid
import random

import db
import db_model


def _count_users(sess) -> int:
    """Return the number of ``users`` rows summed over every shard.

    The textual query is not routed to a particular shard, so one count row
    comes back per shard; their sum is the total.
    """
    # Plain-string SQL is deprecated in SQLAlchemy 1.4 and rejected in 2.0.
    from sqlalchemy import text

    return sum(row[0] for row in sess.execute(text("SELECT count(1) from users")))


def _seed_users(sess, count: int = 32) -> None:
    """Insert ``count`` users, each with a matching UserParam, in one transaction.

    Each pair gets a random shard id taken from ``db.dict_shard`` so the seed
    data always matches the configured shards.
    """
    shard_ids = list(db.dict_shard)
    with sess.begin():  # commits on success, rolls back on error
        for n in range(count):
            group_id = random.choice(shard_ids)
            sess.add(db_model.User(id=n, name=uuid.uuid4().hex, group_id=group_id))
            sess.add(
                db_model.UserParam(id=n, group_id=group_id, param=uuid.uuid4().hex)
            )


def main():
    """Seed the shards on first run, then demo get / filtered query / count."""
    with db.Session(autocommit=False, autoflush=False) as sess:
        if _count_users(sess) == 0:
            # The count query started a transaction; end it so that
            # _seed_users can open its own with sess.begin().
            sess.rollback()
            _seed_users(sess)

        r = sess.get(db_model.User, 1)
        if r is None:
            print("sess.get  > not found")
        else:
            print("sess.get  > {:02d}, {:s}, {:s}".format(r.id, r.group_id, r.name))

        query = sess.query(db_model.User).where(db_model.User.group_id == "group_1")
        for r in query:
            print("sess.query> {:02d}, {:s}, {:s}".format(r.id, r.group_id, r.name))

        print("total_record_count:", _count_users(sess))


if __name__ == "__main__":
    main()