FastAPI의 OAuth2를 사용하여 기본적인 사용자 인증 시스템을 구축하는 방법에 대해 알아보자.
FastAPI의 OAuth2PasswordBearer
와 OAuth2PasswordRequestForm
을 사용하여 사용자의 로그인 요청을 처리하고, Pydantic 모델을 통해 사용자 데이터를 관리한다. 또한, 암호화된 비밀번호의 해싱과 검증, JWT의 생성 및 검증 과정을 구현함으로써 안전한 사용자 인증 과정을 구축한다.
models.py
: 사용자 모델을 정의.
여기에는 User
및 UserInDB
클래스가 포함된다. 이 파일은 데이터 모델과 관련된 모든 정의를 담당한다.
auth.py
: 사용자 인증과 관련된 함수를 포함.
get_user
와 fake_hash_password
같은 함수를 이 파일에 배치한다. 사용자 인증 및 관련 정보 검증을 담당한다.
token.py
: 토큰 관련 함수를 포함.
fake_decode_token
, get_current_user
, get_current_active_user
함수가 여기에 들어간다. 토큰 생성, 검증, 사용자 인증 상태 관리가 이루어진다.
routers.py
: API 엔드포인트를 정의.
/token
및 /users/me
같은 경로와 관련된 핸들러 함수를 이 파일에 넣는다. API 요청 및 응답 처리를 담당한다.
main.py
: FastAPI 앱 인스턴스를 생성하고 구성.
앱 실행과 관련된 기본 설정을 포함한다.
아래 코드 순서는 작성 순서대로 정리했다.
: FastAPI 애플리케이션 인스턴스 생성 및 라우터 설정
from fastapi import FastAPI
from routers import router
app = FastAPI()
app.include_router(router)
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", port=5005)
: User 및 UserInDB 모델 정의
from pydantic import BaseModel
# User 모델은 사용자에게 노출되는 정보를 포함합니다.
# 예를 들어, 사용자 이름, 이메일, 전체 이름과 같은 정보가 이에 해당합니다.
class User(BaseModel):
username: str
email: str | None = None
full_name: str | None = None
disabled: bool | None = None
# UserInDB 모델은 데이터베이스에 저장되는 사용자 정보를 포함합니다.
# 이 모델은 User 모델을 확장하고, 추가적으로 사용자의 해시된 비밀번호와 같은 민감한 정보를 포함합니다.
# 해시된 비밀번호와 같은 민감한 정보는 일반적으로 사용자에게 직접 노출되어서는 안 되는 정보입니다.
# UserInDB를 별도로 사용함으로써, 이러한 민감한 정보를 안전하게 관리할 수 있습니다.
class UserInDB(User):
hashed_password: str
# 사용자 인증이나 사용자 관련 로직을 처리할 때, UserInDB 모델에서 사용자의 해시된 비밀번호를 사용하고,
# 사용자에게 정보를 노출할 때는 User 모델을 사용합니다.
: 사용자 인증 및 비밀번호 관리
from schemas import UserInDB
fake_users_db = {
"johndoe": {
"username": "johndoe",
"full_name": "John Doe",
"email": "johndoe@example.com",
"hashed_password": "fakehashedsecret",
"disabled": False,
},
"alice": {
"username": "alice",
"full_name": "Alice Wonderson",
"email": "alice@example.com",
"hashed_password": "fakehashedsecret2",
"disabled": True,
},
}
def get_user(db, username: str):
if username in db:
return UserInDB(**db[username])
def fake_hash_password(password: str):
return "fakehashed" + password
: 토큰 생성 및 검증
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from auth import get_user, fake_users_db
from schemas import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def fake_decode_token(token):
return get_user(fake_users_db, token)
async def get_current_user(token: str = Depends(oauth2_scheme)):
user = fake_decode_token(token)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
return user
async def get_current_active_user(current_user: User = Depends(get_current_user)):
if current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
: API 라우트 정의
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from auth_token import get_current_active_user
from auth import fake_users_db, fake_hash_password, UserInDB
from schemas import User
router = APIRouter()
@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
user_dict = fake_users_db.get(form_data.username)
if not user_dict:
raise HTTPException(status_code=400, detail="Incorrect username or password")
user = UserInDB(**user_dict)
hashed_password = fake_hash_password(form_data.password)
if hashed_password != user.hashed_password:
raise HTTPException(status_code=400, detail="Incorrect username or password")
return {"access_token": user.username, "token_type": "bearer"}
@router.get("/users/me")
async def read_users_me(current_user: User = Depends(get_current_active_user)):
return current_user
JWT token을 사용하기 위해서 필요한 패키지 설치한다.
JWT와 관련된 작업을 위해 python-jose와 passlib를 설치해야한다.
python-jose
> pip install 'python-jose[cryptography]'
python-jose
는 JSON Web Tokens(JWT)를 생성하고 검증하는 데 사용되는 라이브러리[cryptography]
는 선택적인 의존성으로, 이를 통해 보다 강력한 암호화 알고리즘과 기능을 사용할 수 있다.passlib
> pip install 'passlib[bcrypt]'
passlib
은 Python에서 비밀번호 해싱과 관리를 위한 라이브러리bcrypt
는 passlib
이 지원하는 여러 해싱 알고리즘 중 하나로, 특히 보안성이 높은 알고리즘으로 알려져있다.[bcrypt]
는 이 라이브러리에서 bcrypt 알고리즘을 사용할 수 있도록 하는 선택적인 의존성이다.FastAPI에서 위 두 라이브러리는 주로 사용자 인증과 관련된 기능을 구현하는 데 사용된다. python-jose
를 통해 안전하게 사용자 토큰을 생성하고 관리할 수 있으며, passlib
과 bcrypt
를 사용하여 사용자의 비밀번호를 안전하게 저장하고 검증할 수 있다.
: 보안 관련 함수 (JWT 토큰 생성 및 검증을 위한 유틸리티 함수 생성)
from passlib.context import CryptContext
from fastapi import HTTPException, status
from jose import JWTError
from typing import Annotated
from models import User, TokenData
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
return pwd_context.hash(password)
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], fake_users_db=dict):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except JWTError:
raise credentials_exception
user = get_user(fake_users_db, username=token_data.username)
if user is None:
raise credentials_exception
return user
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
):
if current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
: 데이터 모델과 관련된 클래스들을 정의한 모듈
from pydantic import BaseModel
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: str | None = None
class User(BaseModel):
username: str
email: str | None = None
full_name: str | None = None
disabled: bool | None = None
class UserInDB(User):
hashed_password: str
: Create, Read, Update, Delete 작업을 수행하는 함수를 담은 모듈
from datetime import datetime, timedelta, timezone
from jose import jwt
from typing import Optional
from security import verify_password
from models import UserInDB
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
def authenticate_user(fake_db, username: str, password: str):
user = fake_db.get(username)
if not user or not verify_password(password, user.get("hashed_password")):
return None
return UserInDB(**user)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
: FastAPI 애플리케이션을 생성하고 설정하는 모듈
from fastapi import FastAPI
from fastapi.security import OAuth2PasswordBearer
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@app.post("/token")
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
) -> Token:
user = authenticate_user(fake_users_db, form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return Token(access_token=access_token, token_type="bearer")
: 엔드포인트를 정의하는 모듈
from fastapi import Depends, HTTPException, status
from crud import read_own_items, read_users_me
from models import User
@app.get("/users/me/", response_model=User)
async def read_users_me(current_user: User = Depends(oauth2_scheme)):
return current_user
@app.get("/users/me/items/")
async def read_own_items(current_user: User = Depends(oauth2_scheme)):
return [{"item_id": "Foo", "owner": current_user.username}]
이렇게 코드를 분리하면 각각의 모듈이 명확한 역할을 가지게 되며 코드의 가독성과 유지 보수성이 향상된다.
덧붙여서 crud.py 모듈에서 OAuth2PasswordRequestForm, HTTPException, status, OAuth2PasswordBearer, Depends 등이 사용되는데 이들을 모듈에서 import하지 않고 있다. 따라서 이들을 적절히 import하여 모듈에 추가해줘야한다.
from datetime import timedelta
from fastapi import Depends, HTTPException, status, OAuth2PasswordRequestForm
from main import app, oauth2_scheme, fake_users_db, ACCESS_TOKEN_EXPIRE_MINUTES
from models import User
from security import create_access_token, authenticate_user
@app.post("/token")
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
user = authenticate_user(fake_users_db, form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(data={"sub": user.username}, expires_delta=access_token_expires)
return {"access_token": access_token, "token_type": "bearer"}
FastAPI에서 WebSocket을 사용하여 채팅 서버를 구현하는 예제다. 클라이언트가 서버에 연결하고, 메시지를 보내고 받을 수 있는 간단한 채팅 서버 구현이 목표다.
: FastAPI 앱 생성 및 라우트 설정
from fastapi import FastAPI
from sockets import router as socket_router
from fastapi.responses import HTMLResponse
from pathlib import Path
app = FastAPI()
app.include_router(socket_router)
# 경로, HTML Response
@app.get("/")
def index():
index_html = Path('index.html').read_text()
return HTMLResponse(index_html)
: WebSocket 관련 코드
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from models import Message
from dependencies import get_current_username
from connection import manager
router = APIRouter()
@router.websocket("/ws/{token}")
async def websocket_endpoint(websocket: WebSocket, token: str):
username = get_current_username(token)
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
message = Message(username=username, text=data)
await manager.broadcast(message.json())
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(f"{username} left the chat")
except Exception as e:
await manager.broadcast(f"Error: {str(e)}")
: WebSocket 연결 관리자
from fastapi import WebSocket
from typing import List
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
: 데이터 모델
from pydantic import BaseModel
class Message(BaseModel):
username: str
text: str
: 의존성(인증 관련)
from fastapi import HTTPException
# 하드코딩된 사용자 데이터 예시
users = {
1: {"email": "user1@example.com"},
2: {"email": "user2@example.com"},
3: {"email": "user3@example.com"}
}
def get_current_username(token: str) -> str:
try:
user_id = int(token) # 토큰을 사용자 ID로 간주
except ValueError:
raise HTTPException(status_code=400, detail="Invalid token format")
user = users.get(user_id)
if user is None:
raise HTTPException(status_code=404, detail="User not found")
return user["email"] # 사용자의 이메일 반환
교재 코드 (DB 이용)
from sqlalchemy.orm import Session
from your_database_models import User # 데이터베이스 모델을 포함하는 모듈
from your_database import get_db # 데이터베이스 세션을 생성하는 함수
from fastapi import Depends, HTTPException
def get_current_username(token: str, db: Session = Depends(get_db)):
# 실제로는 토큰을 사용하여 사용자를 조회합니다.
# 여기서는 예시로 token을 사용자의 id로 가정합니다.
try:
user_id = int(token) # 예시로, 토큰을 사용자 ID로 간주합니다.
except ValueError:
raise HTTPException(status_code=400, detail="Invalid token format")
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise HTTPException(status_code=404, detail="User not found")
return user.email # 사용자의 이메일을 반환합니다.
ConnectionManager
클래스는 활성 WebSocket 연결을 관리한다.
/ws
경로에서 WebSocket 연결을 수락한다.
클라이언트로부터 메시지를 받으면 해당 메시지를 모든 클라이언트에게 방송한다.
클라이언트가 연결을 끊으면, 나머지 클라이언트들에게 알린다.
서버를 실행하기 위해 다음 명령어를 사용한다.
> uvicorn main:app --reload
클라이언트 측에서는 JavaScript를 사용하여 WebSocket 연결을 열고 메시지를 보낼 수 있다.
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Chat</title>
<script>
var websocket;
function connect() {
var token = document.getElementById("token").value;
websocket = new WebSocket("ws://127.0.0.1:8000/ws/" + token);
websocket.onmessage = function(event) {
var messages = document.getElementById("messages");
var message = JSON.parse(event.data);
var newMessage = document.createElement("div");
newMessage.appendChild(document.createTextNode(message.username + ": " + message.text));
messages.appendChild(newMessage);
};
websocket.onerror = function(event) {
console.error("WebSocket error observed:", event);
};
}
function sendMessage() {
var input = document.getElementById("messageText");
websocket.send(input.value);
input.value = '';
}
</script>
</head>
<body>
<h1>WebSocket Chat</h1>
<input type="text" id="token" placeholder="Enter token here">
<button onclick="connect()">Connect</button>
<div id="messages" style="border:1px solid #ccc; height:200px; overflow:auto; margin-bottom:10px;"></div>
<input type="text" id="messageText" placeholder="Enter message here">
<button onclick="sendMessage()">Send</button>
</body>
</html>
이미지 분류 API로 TensorFlow를 사용하여 이미지를 분류한다.
project_root/
│
├── app/ # 애플리케이션 폴더
│ ├── main.py # FastAPI 앱 생성 및 라우팅 설정
│ ├── models.py # 데이터 모델 정의
│ ├── dependencies.py # 의존성 관리 및 유틸리티 함수
│ └── routes/ # 엔드포인트와 관련된 라우팅 파일
│ └── image_classifier.py # 이미지 분류 API 라우팅
│
└── model/ # 모델 관련 파일
└── model_loader.py # 모델 로드 관련 함수
main.py
: FastAPI 앱 인스턴스를 생성하고 설정한다. 다른 라우팅 파일들을 앱에 포함시킨다.
models.py
: API에서 사용되는 데이터 모델들을 정의한다. 이 경우에는 별도의 데이터 모델이 필요하지 않을 수도 있다.
routes/image_classifier.py
: 이미지 분류 API의 라우팅과 로직을 정의한다. TensorFlow 모델을 사용하여 이미지를 분류하는 엔드포인트를 포함한다.
model/model_loader.py
: TensorFlow 모델을 로드하는 로직을 포함한다. 여기서는 MobileNetV2 모델을 로드하고, 이를 이미지 분류에 사용한다.
tensorflow
> pip install tensorflow
Pillow
> pip install Pillow
: 파이썬에서 이미지 파일을 다룰 수 있도록 도와주는 역할
: TensorFlow 모델 로드
# ts에서 model 불러오기
import tensorflow as tf
def load_model():
model = tf.keras.applications.MobileNetV2(weights="imagenet")
print("Model Load Successfully")
return model
model = load_model()
tf.keras.applications.MobileNetV2(weights="imagenet")
load_model
: 이미지 예측 함수
# Tensorflow -> Image Model 불러오기
# Pillow -> Image 관련 Module
from PIL.Image import Image
import numpy as np
from tensorflow.keras.applications.imagenet_utils import decode_predictions
from model_loader import model
# AI가 이해할 수 있는 데이터로 변경
def predict(image: Image):
image = np.asarray(image.resize((224, 224)))[..., :3] # RGB
image = np.expand_dims(image, 0) # 차원 확장
image = image / 127.5 -1.0 # Scaler(정규화) -> 이미지 데이터가 -1 ~ 1 형태 값으로 정규화
results = decode_predictions(model.predict(image), 3)[0]
print('results: ', results)
result_list = []
for i in results:
result_list.append({"class": i[1], "confidence": f"{i[2]*100:0.2f} %"})
return result_list
decode_predictions
predict
: API 엔드포인트
from fastapi import FastAPI, UploadFile, File
from PIL import Image
from io import BytesIO
from predict import predict
app = FastAPI()
@app.post('/predict/image')
async def predict_image(file: UploadFile=File(...)):
image = Image.open(BytesIO(await file.read()))
result = predict(image)
return result
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", reload=True)
@app.post("/predict/image")
: 이미지 파일을 받아 분류 결과를 반환하는 POST 요청 엔드포인트를 설정한다.
Image.open(BytesIO(await file.read()))
: 업로드된 이미지 파일을 읽어 PIL 이미지 객체로 변환한다.
강사님 코드
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from io import BytesIO
from predict import predict
app = FastAPI()
@app.post("/predict/image")
async def predict_api(file: UploadFile = File(...)):
extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
if not extension:
return "Image must be jpg or png format!"
image = Image.open(BytesIO(await file.read()))
prediction = predict(image)
return prediction
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, debug=True)
확장자를 넣어 해당 확장자가 아닐 경우 안내문구가 나오도록 해줬다.
ImageDetection 실습을 하는데 running time 오류가 있어서 pip install python-multipart
를 설치했다.
http://127.0.0.1:8000/docs
사이트에 접속해서 파일을 올리려고 하는데 TypeError: a bytes-like object is required, not 'coroutine'
가 발생해서 비동기 함수로 변경해줬다. (async def
와 await
추가)
고양이 사진을 올렸는데
이 고양이는 Egyptian 고양이로 판명...! 🐈