[FastAPI] SQLAlchemy Wrapper

Hyeseong·2021년 5월 4일
2

들어가기 앞서

API 키 생성을 진행 하겠습니다.

본론

API키 생성을 위해서 routes/users.py 모듈에 관련 비동기 함수를 만들게 됩니다.

from typing import List
from uuid import uuid4

import bcrypt
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from starlette.requests import Request
from fastapi.logger import logger

from app.common.consts import MAX_API_KEY
from app.database.conn import db
from app.database.schema import Users, ApiKeys
from app import models as m
from app.errors import exceptions as ex
import string
import secrets

from app.models import MessageOk

router = APIRouter(prefix='/user')


@router.get('/me')
async def get_me(request: Request):
    user = request.state.user
    user_info = Users.get(id=user.id)
    return user_info


@router.put('/me')
async def put_me(request: Request):
    ...


@router.delete('/me')
async def delete_me(request: Request):
    ...


@router.get('/apikeys', response_model=List[m.GetApiKeyList])
async def get_api_keys(request: Request):
    user = request.state.user
    api_keys = ApiKeys.filter(user_id=user.id).filter(id__gt=1).all()
    return api_keys


@router.post('/apikeys', response_model=m.GetApiKeys)
async def create_api_keys(request: Request, key_info: m.AddApiKey, session: Session = Depends(db.session)):
    user = request.state.user

    api_keys = ApiKeys.filter(session, user_id=user.id, status='active').count()
    if api_keys == MAX_API_KEY:
        raise ex.MaxKeyCountEx()

    alphabet = string.ascii_letters + string.digits
    s_key = ''.join(secrets.choice(alphabet) for i in range(40))
    uid = f"{str(uuid4())[:-12]}{str(uuid4())}"

    key_info = key_info.dict()

    try:
        new_key = ApiKeys.create(session, auto_commit=True, secret_key=s_key, user_id=user.id, access_key=uid, **key_info)
    except Exception as e:
        raise ex.SqlFailureEx(e)
    return new_key


@router.delete('/apikeys/{key_id}')
async def delete_api_keys(request: Request, key_id: int, access_key: str):
    user = request.state.user
    key_data = ApiKeys.get(access_key=access_key)
    if key_data and key_data.id == key_id and key_data.user_id == user.id:
        ApiKeys.filter(id=key_id).delete(auto_commit=True)
        return MessageOk()
    else:
        raise ex.NoKeyMatchEx()

import된 것들 중 Users와 ApiKeys에 대해서 정리하고 가보려고합니다.

database/schema.py

from datetime import datetime, timedelta

from sqlalchemy import (
    Column,
    Integer,
    String,
    DateTime,
    func,
    Enum,
    Boolean,
    ForeignKey,
    JSON,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

from app.database.conn import Base, db


class BaseMixin:
    id = Column(Integer, primary_key=True, index=True)
    created_at = Column(DateTime, nullable=False, default=func.utc_timestamp())
    updated_at = Column(DateTime, nullable=False, default=func.utc_timestamp(), onupdate=func.utc_timestamp())

    def __init__(self):
        self._q = None
        self._session = None

    def all_columns(self):
        return [c for c in self.__table__.columns if c.primary_key is False and c.name != "created_at"]

    def __hash__(self):
        return hash(self.id)

    @classmethod
    def create(cls, session: Session, auto_commit=False, **kwargs):
        """
        테이블 데이터 적재 전용 함수
        :param session:
        :param auto_commit: 자동 커밋 여부
        :param kwargs: 적재 할 데이터
        :return:
        """
        obj = cls()
        for col in obj.all_columns():
            col_name = col.name
            if col_name in kwargs:
                setattr(obj, col_name, kwargs.get(col_name))
        session.add(obj)
        session.flush()
        if auto_commit:
            session.commit()
        return obj

    @classmethod
    def get(cls, **kwargs):
        """
        Simply get a Row
        :param kwargs:
        :return:
        """
        session = next(db.session())
        query = session.query(cls)
        for key, val in kwargs.items():
            col = getattr(cls, key)
            query = query.filter(col == val)

        if query.count() > 1:
            raise Exception("Only one row is supposed to be returned, but got more than one.")
        result = query.first()
        session.close()
        return result


    @classmethod
    def filter(cls, session: Session = None, **kwargs):
        """
        Simply get a Row
        :param session:
        :param kwargs:
        :return:
        """
        cond = []
        for key, val in kwargs.items():
            key = key.split("__")
            if len(key) > 2:
                raise Exception("No 2 more dunders")
            col = getattr(cls, key[0])
            if len(key) == 1: cond.append((col == val))
            elif len(key) == 2 and key[1] == 'gt': cond.append((col > val))
            elif len(key) == 2 and key[1] == 'gte': cond.append((col >= val))
            elif len(key) == 2 and key[1] == 'lt': cond.append((col < val))
            elif len(key) == 2 and key[1] == 'lte': cond.append((col <= val))
            elif len(key) == 2 and key[1] == 'in': cond.append((col.in_(val)))

        obj = cls()
        if session:
            obj._session = session
            obj._sess_served = True
        else:
            obj._session = next(db.session())
            obj._sess_served = False
        query = obj._session.query(cls)
        query = query.filter(*cond)
        obj._q = query
        return obj


    @classmethod
    def cls_attr(cls, col_name=None):
        if col_name:
            col = getattr(cls, col_name)
            return col
        else:
            return cls

    def order_by(self, *args: str):
        for a in args:
            if a.startswith("-"):
                col_name = a[1:]
                is_asc = False
            else:
                col_name = a
                is_asc = True
            col = self.cls_attr(col_name)
            self._q = self._q.order_by(col.asc()) if is_asc else self._q.order_by(col.desc())
        return self


    def update(self, sess: Session = None, auto_commit: bool = False, **kwargs):
        cls = self.cls_attr()
        if sess:
            query = sess.query(cls)
        else:
            sess = next(db.session())
            query = sess.query(cls)
        self.close()
        return query.update(**kwargs)

    def first(self):
        result = self._q.first()
        self.close()
        return result

    def delete(self, auto_commit: bool = False, **kwargs):
        self._q.delete()
        if auto_commit:
            self._session.commit()
        self.close()

    def all(self):
        result = self._q.all()
        self.close()
        return result

    def count(self):
        result = self._q.count()
        self.close()
        return result

    def dict(self, *args: str):
        q_dict = {}
        for c in self.__table__.columns:
            if c.name in args:
                q_dict[c.name] = getattr(self, c.name)

        return q_dict

    def close(self):
        if self._sess_served:
            self._session.commit()
            self._session.close()
        else:
            self._session.flush()


class Users(Base, BaseMixin):
    __tablename__ = "users"
    status = Column(Enum("active", "deleted", "blocked"), default="active")
    email = Column(String(length=255), nullable=True)
    pw = Column(String(length=2000), nullable=True)
    name = Column(String(length=255), nullable=True)
    phone_number = Column(String(length=20), nullable=True, unique=True)
    profile_img = Column(String(length=1000), nullable=True)
    sns_type = Column(Enum("FB", "G", "K"), nullable=True)
    marketing_agree = Column(Boolean, nullable=True, default=True)


class ApiKeys(Base, BaseMixin):
    __tablename__ = "api_keys"
    access_key = Column(String(length=64), nullable=False, index=True)
    secret_key = Column(String(length=64), nullable=False)
    user_memo = Column(String(length=40), nullable=True)
    status = Column(Enum("active", "stopped", "deleted"), default="active")
    is_whitelisted = Column(Boolean, default=False)
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

내부에는 많은 인스턴스 매소드와 클래스 메소드가 있는데요.
간단한 쿼리를 ORM으로 구현해서 간편하게 코드의 재사용성을 높이고자 작성한 부분이고요.
단점은 물론 복잡한 쿼리를 작성시 오히려 가독성이나 추후 유지 측면에서도 효율이 떨어지는데요. 그때는 별도로 복잡한 쿼리에 대한 작성을 하는 것이 더 낳아요.

데이터베이스에서 말하는 CRUD를 모두 구현은 해 둔 부분이조.
routes/users.py로 돌아가서 코드 하나하나씩 분석 할게요.

BaseMixin

get_me에서 User클래스의 get()메소드는 BaseMixin클래스의 @classmethod로 구현되어 있어요. 이외 다른 매직메소드, 클래스 메소드, 인스턴스 매소드들을 살펴 볼게요.

아래 15개의 메소드들이 BaseMixin을 통해 정의된게 보입니다.

get classmethod

무엇보다 get은 한개의 로우만 가져와야해요.
index key 또는 unique key로만 검사를 해야합니다. 만약 2개 이상 리턴 될 경우에는 데이터의 무결성에 문제가 생긴거구요.

그렇지만 그 경우 대비해서 exception 처리를 하도록 조건을 걸어뒀고요.

그럼 query.first()통해서 하나의 로우 정보를 result로 할당하고 세션을 닫은 후 이를 메소드에서 리턴해줘요.

filter classmethod

여러 로우를 가져오는데요.

id = 3, 3인값
idgt = 3 초과
id
gte = 3이상
idlt = 3 미만
id
lte = 3 이하
id__in = [?,?,?] 해당 값이 리스트에 포함

위와 같이 된 로직 구조에요.

특히 obj._q = query부분에서 클래스 변수인 _q부분을 통해서 실질적인 쿼리가 수행되요.

예)

  • self._q.count()
  • self._q.all()
  • self._q.delete()
  • self._q.first()
  • elf._q.order_by(col.asc()) if is_asc else self._q.order_by(col.desc())

참고로 SQLAlchemy에서 객체를 아무리 만들어도 생성자(init)을 만들지 않아요.(흥미로운 사실)
결론적으로 아예 더블언더 init메서드가 아예 없어도 상관 없다는 말이지만, 여기서는 명시적으로 표현하여 이해를 높이도록 한다고 생각하면되요.
조금더 부연 설명하면 만약 obj=cls()가 정의된 소스코드 바로 아래에 print(obj._q)를 찍으면 None이 나오는 것이 아닌 없다고 콘솔에서 얘기를 해줘요.

user.py

다시 돌아 와서 get_me 함수를 살펴볼게요.

27번째 줄을 보게 되면 마치 django의 향기가 물씬 나는게 느껴집니다.

조금더 비교를 해볼게요.
SQLAlchemy V.S 커스터마이징 ORM

코드의 전체적인 길이를 줄이는 측면과 중복된 클래스명과 속성명을 리팩토링한다는 측면에서 긍정적으로 볼 수 있어요.

profile
어제보다 오늘 그리고 오늘 보다 내일...

0개의 댓글