1. FastAPI Simple OAuth2

FastAPI의 OAuth2를 사용하여 기본적인 사용자 인증 시스템을 구축하는 방법에 대해 알아보자.

FastAPI의 OAuth2PasswordBearerOAuth2PasswordRequestForm을 사용하여 사용자의 로그인 요청을 처리하고, Pydantic 모델을 통해 사용자 데이터를 관리한다. 또한, 암호화된 비밀번호의 해싱과 검증, JWT의 생성 및 검증 과정을 구현함으로써 안전한 사용자 인증 과정을 구축한다.

  1. models.py
    : 사용자 모델을 정의.
    여기에는 UserUserInDB 클래스가 포함된다. 이 파일은 데이터 모델과 관련된 모든 정의를 담당한다.

  2. auth.py
    : 사용자 인증과 관련된 함수를 포함.
    get_userfake_hash_password 같은 함수를 이 파일에 배치한다. 사용자 인증 및 관련 정보 검증을 담당한다.

  3. token.py
    : 토큰 관련 함수를 포함.
    fake_decode_token, get_current_user, get_current_active_user 함수가 여기에 들어간다. 토큰 생성, 검증, 사용자 인증 상태 관리가 이루어진다.

  4. routers.py
    : API 엔드포인트를 정의.
    /token/users/me 같은 경로와 관련된 핸들러 함수를 이 파일에 넣는다. API 요청 및 응답 처리를 담당한다.

  5. main.py
    : FastAPI 앱 인스턴스를 생성하고 구성.
    앱 실행과 관련된 기본 설정을 포함한다.

아래 코드 순서는 작성 순서대로 정리했다.

01. main.py

: FastAPI 애플리케이션 인스턴스 생성 및 라우터 설정

from fastapi import FastAPI
from routers import router

app = FastAPI()
app.include_router(router)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", port=5005)

02. schemas.py

: User 및 UserInDB 모델 정의

from pydantic import BaseModel

# User 모델은 사용자에게 노출되는 정보를 포함합니다.
# 예를 들어, 사용자 이름, 이메일, 전체 이름과 같은 정보가 이에 해당합니다.
class User(BaseModel):
    username: str
    email: str | None = None
    full_name: str | None = None
    disabled: bool | None = None


# UserInDB 모델은 데이터베이스에 저장되는 사용자 정보를 포함합니다. 
# 이 모델은 User 모델을 확장하고, 추가적으로 사용자의 해시된 비밀번호와 같은 민감한 정보를 포함합니다.
# 해시된 비밀번호와 같은 민감한 정보는 일반적으로 사용자에게 직접 노출되어서는 안 되는 정보입니다. 
# UserInDB를 별도로 사용함으로써, 이러한 민감한 정보를 안전하게 관리할 수 있습니다.
class UserInDB(User):
    hashed_password: str


# 사용자 인증이나 사용자 관련 로직을 처리할 때, UserInDB 모델에서 사용자의 해시된 비밀번호를 사용하고, 
# 사용자에게 정보를 노출할 때는 User 모델을 사용합니다.

03. auth.py

: 사용자 인증 및 비밀번호 관리

from schemas import UserInDB

fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "full_name": "John Doe",
        "email": "johndoe@example.com",
        "hashed_password": "fakehashedsecret",
        "disabled": False,
    },
    "alice": {
        "username": "alice",
        "full_name": "Alice Wonderson",
        "email": "alice@example.com",
        "hashed_password": "fakehashedsecret2",
        "disabled": True,
    },
}

def get_user(db, username: str):
    if username in db:
        return UserInDB(**db[username])

def fake_hash_password(password: str):
    return "fakehashed" + password

04. auth_token.py

: 토큰 생성 및 검증

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from auth import get_user, fake_users_db
from schemas import User

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

def fake_decode_token(token):
    return get_user(fake_users_db, token)

async def get_current_user(token: str = Depends(oauth2_scheme)):
    user = fake_decode_token(token)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
    return user

async def get_current_active_user(current_user: User = Depends(get_current_user)):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

05. routers.py

: API 라우트 정의

from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from auth_token import get_current_active_user
from auth import fake_users_db, fake_hash_password, UserInDB
from schemas import User

router = APIRouter()

@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    user_dict = fake_users_db.get(form_data.username)
    if not user_dict:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    user = UserInDB(**user_dict)
    hashed_password = fake_hash_password(form_data.password)
    if hashed_password != user.hashed_password:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    return {"access_token": user.username, "token_type": "bearer"}

@router.get("/users/me")
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user

2. FastAPI SQL

JWT token을 사용하기 위해서 필요한 패키지 설치한다.

00. 필요한 패키지 설치

JWT와 관련된 작업을 위해 python-jose와 passlib를 설치해야한다.

  • python-jose
    > pip install 'python-jose[cryptography]'
    • python-jose.PyPI
    • python-jose는 JSON Web Tokens(JWT)를 생성하고 검증하는 데 사용되는 라이브러리
    • JWT는 사용자 인증 및 정보 교환에 널리 사용되는 방식으로, 토큰 형태로 안전하게 정보를 전송할 수 있도록 한다.
    • 이 라이브러리는 JWT를 생성하고, 이를 서명하며, 또한 이러한 토큰을 해독하고 검증하는 데 필요한 기능을 제공한다.
    • [cryptography]는 선택적인 의존성으로, 이를 통해 보다 강력한 암호화 알고리즘과 기능을 사용할 수 있다.
  • passlib
    > pip install 'passlib[bcrypt]'
    • passlib은 Python에서 비밀번호 해싱과 관리를 위한 라이브러리
    • 비밀번호를 안전하게 저장하기 위해서는 단순한 저장 대신 해싱을 사용해야한다. 해싱은 비밀번호를 원래의 형태로 되돌릴 수 없는 데이터로 변환하는 과정이다.
    • bcryptpasslib이 지원하는 여러 해싱 알고리즘 중 하나로, 특히 보안성이 높은 알고리즘으로 알려져있다.
    • [bcrypt]는 이 라이브러리에서 bcrypt 알고리즘을 사용할 수 있도록 하는 선택적인 의존성이다.

FastAPI에서 위 두 라이브러리는 주로 사용자 인증과 관련된 기능을 구현하는 데 사용된다. python-jose를 통해 안전하게 사용자 토큰을 생성하고 관리할 수 있으며, passlibbcrypt를 사용하여 사용자의 비밀번호를 안전하게 저장하고 검증할 수 있다.

01. security.py

: 보안 관련 함수 (JWT 토큰 생성 및 검증을 위한 유틸리티 함수 생성)

from passlib.context import CryptContext
from fastapi import HTTPException, status
from jose import JWTError
from typing import Annotated
from models import User, TokenData

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password):
    return pwd_context.hash(password)
    
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], fake_users_db=dict):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)]
):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

02. models.py

: 데이터 모델과 관련된 클래스들을 정의한 모듈

from pydantic import BaseModel

class Token(BaseModel):
    access_token: str
    token_type: str

class TokenData(BaseModel):
    username: str | None = None

class User(BaseModel):
    username: str
    email: str | None = None
    full_name: str | None = None
    disabled: bool | None = None

class UserInDB(User):
    hashed_password: str

03. crud.py

: Create, Read, Update, Delete 작업을 수행하는 함수를 담은 모듈

from datetime import datetime, timedelta, timezone
from jose import jwt
from typing import Optional

from security import verify_password
from models import UserInDB

SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

def authenticate_user(fake_db, username: str, password: str):
    user = fake_db.get(username)
    if not user or not verify_password(password, user.get("hashed_password")):
        return None
    return UserInDB(**user)

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

04. main.py

: FastAPI 애플리케이션을 생성하고 설정하는 모듈

from fastapi import FastAPI
from fastapi.security import OAuth2PasswordBearer

app = FastAPI()

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

@app.post("/token")
async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
) -> Token:
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return Token(access_token=access_token, token_type="bearer")

05. routes.py

: 엔드포인트를 정의하는 모듈

from fastapi import Depends, HTTPException, status
from crud import read_own_items, read_users_me
from models import User

@app.get("/users/me/", response_model=User)
async def read_users_me(current_user: User = Depends(oauth2_scheme)):
    return current_user

@app.get("/users/me/items/")
async def read_own_items(current_user: User = Depends(oauth2_scheme)):
    return [{"item_id": "Foo", "owner": current_user.username}]

이렇게 코드를 분리하면 각각의 모듈이 명확한 역할을 가지게 되며 코드의 가독성과 유지 보수성이 향상된다.

+) 수정된 crud.py

덧붙여서 crud.py 모듈에서 OAuth2PasswordRequestForm, HTTPException, status, OAuth2PasswordBearer, Depends 등이 사용되는데 이들을 모듈에서 import하지 않고 있다. 따라서 이들을 적절히 import하여 모듈에 추가해줘야한다.

from datetime import timedelta
from fastapi import Depends, HTTPException, status, OAuth2PasswordRequestForm
from main import app, oauth2_scheme, fake_users_db, ACCESS_TOKEN_EXPIRE_MINUTES
from models import User
from security import create_access_token, authenticate_user

@app.post("/token")
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(data={"sub": user.username}, expires_delta=access_token_expires)
    return {"access_token": access_token, "token_type": "bearer"}

3. FastAPI Advanced

01. FastAPI Websocket

FastAPI에서 WebSocket을 사용하여 채팅 서버를 구현하는 예제다. 클라이언트가 서버에 연결하고, 메시지를 보내고 받을 수 있는 간단한 채팅 서버 구현이 목표다.

1) main.py

: FastAPI 앱 생성 및 라우트 설정

from fastapi import FastAPI
from sockets import router as socket_router
from fastapi.responses import HTMLResponse
from pathlib import Path

app = FastAPI()
app.include_router(socket_router)

# 경로, HTML Response
@app.get("/")
def index():
    index_html = Path('index.html').read_text()
    return HTMLResponse(index_html)

2) websockets.py

: WebSocket 관련 코드

from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from models import Message
from dependencies import get_current_username
from connection import manager

router = APIRouter()

@router.websocket("/ws/{token}")
async def websocket_endpoint(websocket: WebSocket, token: str):
    username = get_current_username(token)
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            message = Message(username=username, text=data)
            await manager.broadcast(message.json())
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast(f"{username} left the chat")
    except Exception as e:
        await manager.broadcast(f"Error: {str(e)}")

3) connection.py

: WebSocket 연결 관리자

from fastapi import WebSocket
from typing import List

class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

    async def broadcast(self, message: str):
        for connection in self.active_connections:
            await connection.send_text(message)

manager = ConnectionManager()

4) models.py

: 데이터 모델

from pydantic import BaseModel

class Message(BaseModel):
    username: str
    text: str

5) dependencies.py

: 의존성(인증 관련)

from fastapi import HTTPException

# 하드코딩된 사용자 데이터 예시
users = {
    1: {"email": "user1@example.com"},
    2: {"email": "user2@example.com"},
    3: {"email": "user3@example.com"}
}

def get_current_username(token: str) -> str:
    try:
        user_id = int(token)  # 토큰을 사용자 ID로 간주
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid token format")

    user = users.get(user_id)
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")

    return user["email"]  # 사용자의 이메일 반환
  • 교재 코드 (DB 이용)

    from sqlalchemy.orm import Session
    from your_database_models import User  # 데이터베이스 모델을 포함하는 모듈
    from your_database import get_db  # 데이터베이스 세션을 생성하는 함수
    from fastapi import Depends, HTTPException
    
    def get_current_username(token: str, db: Session = Depends(get_db)):
        # 실제로는 토큰을 사용하여 사용자를 조회합니다.
        # 여기서는 예시로 token을 사용자의 id로 가정합니다.
        try:
            user_id = int(token)  # 예시로, 토큰을 사용자 ID로 간주합니다.
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid token format")
    
        user = db.query(User).filter(User.id == user_id).first()
        if user is None:
            raise HTTPException(status_code=404, detail="User not found")
    
        return user.email  # 사용자의 이메일을 반환합니다.
  • ConnectionManager 클래스는 활성 WebSocket 연결을 관리한다.

  • /ws 경로에서 WebSocket 연결을 수락한다.

  • 클라이언트로부터 메시지를 받으면 해당 메시지를 모든 클라이언트에게 방송한다.

  • 클라이언트가 연결을 끊으면, 나머지 클라이언트들에게 알린다.

서버를 실행하기 위해 다음 명령어를 사용한다.

> uvicorn main:app --reload

클라이언트 측에서는 JavaScript를 사용하여 WebSocket 연결을 열고 메시지를 보낼 수 있다.

+) index.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Chat</title>
    <script>
        var websocket;

        function connect() {
            var token = document.getElementById("token").value;
            websocket = new WebSocket("ws://127.0.0.1:8000/ws/" + token);

            websocket.onmessage = function(event) {
                var messages = document.getElementById("messages");
                var message = JSON.parse(event.data);
                var newMessage = document.createElement("div");
                newMessage.appendChild(document.createTextNode(message.username + ": " + message.text));
                messages.appendChild(newMessage);
            };

            websocket.onerror = function(event) {
                console.error("WebSocket error observed:", event);
            };
        }

        function sendMessage() {
            var input = document.getElementById("messageText");
            websocket.send(input.value);
            input.value = '';
        }
    </script>
</head>
<body>
    <h1>WebSocket Chat</h1>
    <input type="text" id="token" placeholder="Enter token here">
    <button onclick="connect()">Connect</button>
    <div id="messages" style="border:1px solid #ccc; height:200px; overflow:auto; margin-bottom:10px;"></div>
    <input type="text" id="messageText" placeholder="Enter message here">
    <button onclick="sendMessage()">Send</button>
</body>
</html>

02. FastAPI AI (feat.ImageDetection)

이미지 분류 API로 TensorFlow를 사용하여 이미지를 분류한다.

- 프로젝트 파일 구조

project_root/
│
├── app/                       # 애플리케이션 폴더
│   ├── main.py                # FastAPI 앱 생성 및 라우팅 설정
│   ├── models.py              # 데이터 모델 정의
│   ├── dependencies.py        # 의존성 관리 및 유틸리티 함수
│   └── routes/                # 엔드포인트와 관련된 라우팅 파일
│       └── image_classifier.py # 이미지 분류 API 라우팅
│
└── model/                     # 모델 관련 파일
    └── model_loader.py        # 모델 로드 관련 함수

- 각 파일의 기능

  1. main.py
    : FastAPI 앱 인스턴스를 생성하고 설정한다. 다른 라우팅 파일들을 앱에 포함시킨다.

  2. models.py
    : API에서 사용되는 데이터 모델들을 정의한다. 이 경우에는 별도의 데이터 모델이 필요하지 않을 수도 있다.

  3. routes/image_classifier.py
    : 이미지 분류 API의 라우팅과 로직을 정의한다. TensorFlow 모델을 사용하여 이미지를 분류하는 엔드포인트를 포함한다.

  4. model/model_loader.py
    : TensorFlow 모델을 로드하는 로직을 포함한다. 여기서는 MobileNetV2 모델을 로드하고, 이를 이미지 분류에 사용한다.

- 라이브러리 설치

  • tensorflow

    > pip install tensorflow
  • Pillow

    > pip install Pillow

    : 파이썬에서 이미지 파일을 다룰 수 있도록 도와주는 역할

1) model_loader.py

: TensorFlow 모델 로드

# ts에서 model 불러오기
import tensorflow as tf

def load_model():
    model = tf.keras.applications.MobileNetV2(weights="imagenet")
    print("Model Load Successfully")
    return model

model = load_model()
  • tf.keras.applications.MobileNetV2(weights="imagenet")
    : MobileNetV2 모델을 불러온다. 이 모델은 ImageNet 데이터셋으로 사전 훈련되었다.
  • load_model
    : 모델을 로드하는 함수다.

2) predict.py

: 이미지 예측 함수

# Tensorflow -> Image Model 불러오기
# Pillow -> Image 관련 Module

from PIL.Image import Image
import numpy as np
from tensorflow.keras.applications.imagenet_utils import decode_predictions
from model_loader import model

# AI가 이해할 수 있는 데이터로 변경
def predict(image: Image):
    image = np.asarray(image.resize((224, 224)))[..., :3]   # RGB
    image = np.expand_dims(image, 0)    # 차원 확장
    image = image / 127.5 -1.0  # Scaler(정규화) -> 이미지 데이터가 -1 ~ 1 형태 값으로 정규화
    results = decode_predictions(model.predict(image), 3)[0]
    print('results: ', results)
    result_list = []
    for i in results:
        result_list.append({"class": i[1], "confidence": f"{i[2]*100:0.2f} %"})
    return result_list
  • decode_predictions
    : 모델이 예측한 결과를 해독하여 가장 가능성이 높은 클래스의 이름을 얻는다.
  • predict
    : 이미지를 처리하고 모델로 예측한 뒤, 결과를 반환하는 함수다.

3) main.py

: API 엔드포인트

from fastapi import FastAPI, UploadFile, File
from PIL import Image
from io import BytesIO
from predict import predict

app = FastAPI()

@app.post('/predict/image')
async def predict_image(file: UploadFile=File(...)):
    image = Image.open(BytesIO(await file.read()))
    result = predict(image)
    return result

if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", reload=True)
  • @app.post("/predict/image")
    : 이미지 파일을 받아 분류 결과를 반환하는 POST 요청 엔드포인트를 설정한다.

  • Image.open(BytesIO(await file.read()))
    : 업로드된 이미지 파일을 읽어 PIL 이미지 객체로 변환한다.

  • 강사님 코드

from fastapi import FastAPI, File, UploadFile
from PIL import Image
from io import BytesIO
from predict import predict

app = FastAPI()

@app.post("/predict/image")
async def predict_api(file: UploadFile = File(...)):
    extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"
    image = Image.open(BytesIO(await file.read()))
    prediction = predict(image)
    return prediction

if __name__ == "__main__":
		import uvicorn
    uvicorn.run(app, debug=True)

확장자를 넣어 해당 확장자가 아닐 경우 안내문구가 나오도록 해줬다.

[작업 결과💫]


[3일차 후기]

ImageDetection 실습을 하는데 running time 오류가 있어서 pip install python-multipart를 설치했다.

http://127.0.0.1:8000/docs 사이트에 접속해서 파일을 올리려고 하는데 TypeError: a bytes-like object is required, not 'coroutine'가 발생해서 비동기 함수로 변경해줬다. (async defawait 추가)

고양이 사진을 올렸는데

이 고양이는 Egyptian 고양이로 판명...! 🐈


[참고 자료]

  • [오즈스쿨 스타트업 웹 개발 초격차캠프 백엔드 Fast API 실시간 강의]
profile
백엔드 코린이😁

0개의 댓글

관련 채용 정보