4. fastapi middleware 추가, response validation 하기

jw·2021년 6월 26일
1

fastapi 경험해보기

목록 보기
4/4
post-thumbnail
post-custom-banner

지난번 까지는 context manager를 사용한 session_scope 안에서 db 접근했었는데요. 이 경우 매 API 작성시 마다 추가해줘야 하기에, middleware 추가를 통해 이를 해결해보도록 하겠습니다.

main.py에 middleware 함수를 작성해보겠습니다

app = FastAPI() 라인과 @app.get('/') 라인 사이에 작성하겠습니다.

from fastapi import Request # 맨 위에 임포트해주세요


@app.middleware("http")
def access_db_middleware(request: Request, call_next):
    with session_scope() as session:
        request.state.session = session # 세션 생성 후 request.state에 저장하여 넘겨줌
        response = call_next(request)   # 다음 함수 실행
    return response                     # 결과 반환

처음엔 request.session으로 넘겨주려 했지만, 이미 session이라는 attribute가 존재했을 뿐 아니라, 추가한 정보가 보존되지 않더라고요..
찾아봤더니 starlette의 Request 객체를 사용하는데 해당 객체는 Request.state에 원하는 데이터를 저장할 수 있다고 합니다.

그럼 GET /diary 함수도 수정해줘야겠네요.

# 변경 전 
@app.get("/diary")
def get_all_diary():
    with session_scope() as session:
        diarys = session.query(Diary).all()
        return diarys
        
# 변경 후
@app.get("/diary")
def get_all_diary(request: Request):
    session = request.state.session
    diarys = session.query(Diary).all()
    return diarys

좀 더 깔끔하게 바뀌었죠?

fastAPI는 response model을 명시할 수 있는데 이 기능을 추가해 보도록 하겠습니다.
@app.get @app.post .. 등 데코레이터함수의 키워드 인자로 response_model 라는 이름으로 pydantic 모델을 넘겨주면 되는데요,

모델을 정의하기 위해 프로젝트 디렉토리에 models.py를 생성하도록 하겠습니다
제가 정의했던 Diary 스키마 기억나시죠?
해당 데이터를 아래와 같이 정의할 수 있겠네요.

# Diary 모델 정의 (models.py)

from datetime import datetime, date
from pydantic import BaseModel
from typing import Optional


class DiaryModel(BaseModel):
    id: str
    content: str
    created: datetime
    updated: Optional[datetime] = None # nullable column
    date: date

    class Config:              # 우리는 dict가 아니라 ORM으로 가져온 객체를 return 하게 되는데
        orm_mode = True        # 해당 내용이 없으면 dict가 아니라고 에러가 뜨게 됩니다.

이제 main.py에서 추가해보도록 하겠습니다
GET /diary의 return 값은 정확히는 DiaryModellist 이므로 아래와 같이 추가해 주면 됩니다.

from typing import List
from models import DiaryModel # 위에 추가해주세요

@app.get("/diary", response_model=List[DiaryModel])
def get_all_diary(request: Request):
    session = request.state.session
    diarys = session.query(Diary).all()
    return diarys

그냥 response_model=DiaryModel이라고 한 경우는 단일 객체를 의미하는데
typing.List를 사용해서 list를 표현할 수 있었네요.

오늘은 여기까지 입니다. 😉
이제 남은건 다른 API 추가네요!
다음 시간에 뵙겠습니다~

profile
개발 공부중입니다!
post-custom-banner

0개의 댓글