"""
FastAPI Wrapper for vLLM Server
- API Key Authentication (Bearer Token)
- Rate Limiting
- CORS Support
- Temperature & n_samples support
"""
import os
import time
import hashlib
from typing import Optional, List, Dict, Any, Literal, Union
from datetime import datetime, timedelta
from contextlib import asynccontextmanager
import httpx
import redis.asyncio as redis
from fastapi import FastAPI, HTTPException, Security, Request, Depends, Response
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field
from dotenv import load_dotenv
from transformers import AutoTokenizer
load_dotenv()
VLLM_URL = os.getenv("VLLM_URL", "http://localhost:8000")
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
ALLOWED_API_KEYS = os.getenv("ALLOWED_API_KEYS", "").split(",")
ALLOWED_API_KEYS = [key.strip() for key in ALLOWED_API_KEYS if key.strip()]
EVALUATION_API_KEYS = os.getenv("EVALUATION_API_KEYS", "").split(",")
EVALUATION_API_KEYS = [key.strip() for key in EVALUATION_API_KEYS if key.strip()]
SAFETY_FILTER_ENABLED = os.getenv("SAFETY_FILTER_ENABLED", "True").lower() == "true"
BLOCKED_KEYWORDS = [
    "폭탄", "폭발물", "화약", "독극물", "청산가리",  # explosives, poisons
    "마약", "필로폰", "메스암페타민", "코카인", "헤로인",  # narcotics
    "총기", "권총", "소총",  # firearms
    "자살", "자해",  # suicide, self-harm
]
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60"))
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "*").split(",")
CORS_ORIGINS = [origin.strip() for origin in CORS_ORIGINS]
MAX_PROMPT_TOKENS = int(os.getenv("MAX_PROMPT_TOKENS", "8000"))
ANOMALY_DUPLICATE_THRESHOLD = int(os.getenv("ANOMALY_DUPLICATE_THRESHOLD", "10"))
ANOMALY_DUPLICATE_WINDOW = int(os.getenv("ANOMALY_DUPLICATE_WINDOW", "300"))
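# Illustrative .env sketch covering the variables above (placeholder values,
# not real keys):
#   VLLM_URL=http://localhost:8000
#   REDIS_URL=redis://localhost:6379
#   ALLOWED_API_KEYS=key-aaa,key-bbb
#   EVALUATION_API_KEYS=key-eval
#   SAFETY_FILTER_ENABLED=True
#   RATE_LIMIT_REQUESTS=60
#   RATE_LIMIT_WINDOW=60
#   MAX_PROMPT_TOKENS=8000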
redis_client: Optional[redis.Redis] = None
tokenizer = None
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""애플리케이션 생명주기 관리"""
global redis_client, tokenizer
try:
print("📥 Loading Qwen tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-4B-Instruct-2507",
trust_remote_code=True
)
print(f"✅ Loaded Qwen tokenizer (max_prompt_tokens: {MAX_PROMPT_TOKENS})")
except Exception as e:
print(f"⚠️ Failed to load tokenizer: {e}")
print("⚠️ Token validation will use fallback method")
tokenizer = None
try:
redis_client = await redis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
await redis_client.ping()
print(f"✅ Connected to Redis at {REDIS_URL}")
except Exception as e:
print(f"⚠️ Failed to connect to Redis: {e}")
print("⚠️ Falling back to in-memory rate limiting")
redis_client = None
yield
if redis_client:
await redis_client.close()
print("✅ Disconnected from Redis")
app = FastAPI(
title="vLLM Proxy API",
description="FastAPI wrapper for vLLM with authentication, rate limiting, and CORS",
version="1.0.1",
lifespan=lifespan
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
security = HTTPBearer()
@app.exception_handler(HTTPException)
async def openai_compatible_exception_handler(request: Request, exc: HTTPException):
"""
HTTPException을 OpenAI SDK 호환 형식으로 변환
OpenAI SDK는 { "error": { "message": ..., "type": ..., "code": ... } } 형식을 기대함
"""
if isinstance(exc.detail, dict):
error_response = {
"error": {
"message": exc.detail.get("message", str(exc.detail)),
"type": exc.detail.get("error", "api_error"),
"code": exc.detail.get("error", "api_error"),
"param": None
}
}
else:
error_response = {
"error": {
"message": str(exc.detail),
"type": "api_error",
"code": None,
"param": None
}
}
if exc.status_code == 429:
error_response["error"]["type"] = "rate_limit_error"
elif exc.status_code == 401:
error_response["error"]["type"] = "authentication_error"
elif exc.status_code == 400:
error_response["error"]["type"] = "invalid_request_error"
return JSONResponse(
status_code=exc.status_code,
content=error_response,
headers=exc.headers
)
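# For example, a 429 from the rate limiter reaches the client as (sketch;
# numbers reflect the default limits):
#   {
#     "error": {
#       "message": "Rate limit exceeded. Maximum 60 requests per 60 seconds.",
#       "type": "rate_limit_error",
#       "code": "rate_limit_exceeded",
#       "param": null
#     }
#   }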
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Pydantic validation 에러를 OpenAI SDK 호환 형식으로 변환"""
return JSONResponse(
status_code=400,
content={
"error": {
"message": str(exc.errors()),
"type": "invalid_request_error",
"code": "validation_error",
"param": None
}
}
)
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""요청 body 로깅 (디버깅용)"""
if request.method == "POST":
try:
body = await request.body()
print(f"📥 Request to {request.url.path}")
print(f" Body: {body.decode('utf-8')[:500]}")
async def receive():
return {"type": "http.request", "body": body}
request._receive = receive
except Exception as e:
print(f"⚠️ Failed to log request body: {e}")
response = await call_next(request)
return response
@app.middleware("http")
async def add_rate_limit_headers(request: Request, call_next):
"""Rate Limit 헤더를 모든 응답에 추가"""
response = await call_next(request)
if hasattr(request.state, "rate_limit_info"):
info = request.state.rate_limit_info
response.headers["X-RateLimit-Limit"] = str(RATE_LIMIT_REQUESTS)
response.headers["X-RateLimit-Remaining"] = str(info.get("remaining", 0))
response.headers["X-RateLimit-Reset"] = str(int(info.get("reset_time", 0)))
return response
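# A response that passed the limiter then carries headers such as (illustrative
# values; the reset time is a Unix timestamp):
#   X-RateLimit-Limit: 60
#   X-RateLimit-Remaining: 42
#   X-RateLimit-Reset: 1700000000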
class CompletionRequest(BaseModel):
"""vLLM Completion 요청 모델"""
model: str
prompt: str
max_tokens: Optional[int] = Field(default=100, ge=1, le=4096)
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
n: Optional[int] = Field(default=1, ge=1, le=10, description="Number of samples to generate (n_samples)")
stream: Optional[bool] = False
stop: Optional[List[str]] = None
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logprobs: Optional[int] = Field(default=None, ge=0, le=20, description="Number of log probabilities to return per token")
echo: Optional[bool] = Field(default=False, description="Echo back the prompt in addition to the completion")
class ChatCompletionContentPartText(BaseModel):
"""텍스트 content part"""
type: Literal["text"]
text: str
class ChatCompletionContentPartImage(BaseModel):
"""이미지 content part"""
type: Literal["image_url"]
image_url: Dict[str, str]
ChatCompletionContentPart = Union[ChatCompletionContentPartText, ChatCompletionContentPartImage]
class ChatMessage(BaseModel):
"""채팅 메시지 모델 (OpenAI compatible)"""
role: Literal["system", "user", "assistant", "tool", "function", "developer"]
content: Union[str, List[ChatCompletionContentPart]]
name: Optional[str] = None
class ChatCompletionRequest(BaseModel):
"""vLLM Chat Completion 요청 모델"""
model: str
messages: List[ChatMessage]
max_tokens: Optional[int] = Field(default=100, ge=1, le=4096)
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
n: Optional[int] = Field(default=1, ge=1, le=10, description="Number of samples to generate (n_samples)")
stream: Optional[bool] = False
stop: Optional[List[str]] = None
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logprobs: Optional[bool] = Field(default=False, description="Whether to return log probabilities")
top_logprobs: Optional[int] = Field(default=None, ge=0, le=20, description="Number of top log probabilities per token (requires logprobs=True)")
response_format: Optional[Dict[str, Any]] = Field(
default=None,
description="Response format for structured output. Example: {'type': 'json_object'} or {'type': 'json_schema', 'json_schema': {...}}"
)
guided_json: Optional[Dict[str, Any]] = Field(
default=None,
description="JSON schema for guided generation (vLLM specific)"
)
guided_regex: Optional[str] = Field(
default=None,
description="Regex pattern for guided generation (vLLM specific)"
)
guided_choice: Optional[List[str]] = Field(
default=None,
description="List of choices for guided generation (vLLM specific)"
)
class Config:
extra = "allow"
def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
"""
API 키 검증
- Bearer 토큰 형식으로 전달된 API 키를 검증
- ALLOWED_API_KEYS 또는 EVALUATION_API_KEYS 목록에 있는 키만 허용
"""
token = credentials.credentials
if not ALLOWED_API_KEYS and not EVALUATION_API_KEYS:
raise HTTPException(
status_code=500,
detail="Server configuration error: No API keys configured"
)
all_valid_keys = ALLOWED_API_KEYS + EVALUATION_API_KEYS
if token not in all_valid_keys:
raise HTTPException(
status_code=401,
detail="Invalid API key"
)
return token
async def check_rate_limit(api_key: str = Depends(verify_api_key)) -> Dict[str, Any]:
"""
Rate Limiting 체크 (Redis 기반)
- API 키별로 요청 횟수 제한
- 슬라이딩 윈도우 방식 사용
- Redis Sorted Set 활용 (score = timestamp)
- 반환: {"api_key": str, "remaining": int, "reset_time": float}
"""
if not redis_client:
return {
"api_key": api_key,
"remaining": RATE_LIMIT_REQUESTS,
"reset_time": time.time() + RATE_LIMIT_WINDOW
}
current_time = time.time()
key = f"rate_limit:{api_key}"
try:
pipe = redis_client.pipeline()
window_start = current_time - RATE_LIMIT_WINDOW
pipe.zremrangebyscore(key, 0, window_start)
pipe.zcard(key)
pipe.zadd(key, {str(current_time): current_time})
pipe.expire(key, RATE_LIMIT_WINDOW + 10)
results = await pipe.execute()
request_count = results[1]
remaining = max(0, RATE_LIMIT_REQUESTS - request_count)
reset_time = current_time + RATE_LIMIT_WINDOW
if request_count >= RATE_LIMIT_REQUESTS:
oldest_request = await redis_client.zrange(key, 0, 0, withscores=True)
if oldest_request:
oldest_timestamp = oldest_request[0][1]
retry_after = int(oldest_timestamp + RATE_LIMIT_WINDOW - current_time)
else:
retry_after = RATE_LIMIT_WINDOW
raise HTTPException(
status_code=429,
detail={
"error": "rate_limit_exceeded",
"message": f"Rate limit exceeded. Maximum {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW} seconds.",
"retry_after": retry_after,
"limit": RATE_LIMIT_REQUESTS,
"window": RATE_LIMIT_WINDOW
},
headers={
"Retry-After": str(retry_after),
"X-RateLimit-Limit": str(RATE_LIMIT_REQUESTS),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(reset_time))
}
)
return {
"api_key": api_key,
"remaining": remaining,
"reset_time": reset_time
}
except redis.RedisError as e:
print(f"⚠️ Redis error in rate limiting: {e}")
return {
"api_key": api_key,
"remaining": RATE_LIMIT_REQUESTS,
"reset_time": time.time() + RATE_LIMIT_WINDOW
}
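# A minimal synchronous sketch of the same sliding-window algorithm, kept for
# reference only (not called by the app). It assumes a redis-py style client
# `r` supplied by the caller.
def _sliding_window_allows(r, key: str, limit: int, window: int) -> bool:
    """Return True if one more request under `key` fits inside the window."""
    now = time.time()
    r.zremrangebyscore(key, 0, now - window)  # drop entries older than the window
    allowed = r.zcard(key) < limit            # count the requests that remain
    r.zadd(key, {str(now): now})              # record the current request
    r.expire(key, window + 10)                # let the key expire once idle
    return allowed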
def is_evaluation_mode(api_key: str) -> bool:
"""API 키가 평가 모드인지 확인"""
return api_key in EVALUATION_API_KEYS
def check_safety_filter(text: str, api_key: str) -> None:
"""
안전 필터 검사
- 차단 키워드가 포함된 요청 거부
- 평가 모드에서는 필터 우회 (로그만 기록)
"""
if not SAFETY_FILTER_ENABLED:
return
evaluation_mode = is_evaluation_mode(api_key)
for keyword in BLOCKED_KEYWORDS:
if keyword in text:
if evaluation_mode:
print(f"⚠️ [Evaluation Mode: True] Safety filter bypassed for keyword: {keyword}")
return
else:
raise HTTPException(
status_code=400,
detail={
"error": "content_policy_violation",
"message": "Your request contains content that violates our usage policy.",
}
)
def count_tokens(text: str) -> int:
"""
텍스트의 토큰 수 계산 (Qwen 기준)
- tokenizer 사용 가능 시: 정확한 토큰 수
- fallback: 대략적 계산 (한글 기준 1글자 ≈ 2토큰)
"""
if tokenizer:
try:
return len(tokenizer.encode(text))
except Exception as e:
print(f"⚠️ Tokenizer error: {e}, using fallback")
return len(text) // 2
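# Under the fallback heuristic, for example, a 16,000-character prompt is
# estimated at 8,000 tokens - exactly the default MAX_PROMPT_TOKENS cap.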
def validate_prompt_tokens(prompt: str) -> None:
"""
Prompt 토큰 수 검증
- MAX_PROMPT_TOKENS 초과 시 HTTPException 발생
"""
token_count = count_tokens(prompt)
    print(f"🔢 Token count: {token_count}")
if token_count > MAX_PROMPT_TOKENS:
raise HTTPException(
status_code=422,
detail={
"error": "validation_error",
"message": f"Prompt too long. Please reduce to {MAX_PROMPT_TOKENS} tokens or less.",
"details": {
"current_tokens": token_count,
"max_tokens": MAX_PROMPT_TOKENS,
"suggestion": "Try shortening your prompt or splitting it into multiple requests"
}
}
)
async def detect_anomaly_pattern(api_key: str, prompt_hash: str) -> None:
"""
동일 프롬프트 중복 요청 탐지 (로깅 전용)
- 동일한 prompt_hash가 윈도우 내에서 반복되는지 감지
- Redis 기반, 로그만 기록
"""
if not redis_client:
return
try:
current_time = time.time()
key = f"anomaly:duplicate:{api_key}:{prompt_hash[:16]}"
pipe = redis_client.pipeline()
window_start = current_time - ANOMALY_DUPLICATE_WINDOW
pipe.zremrangebyscore(key, 0, window_start)
pipe.zcard(key)
pipe.zadd(key, {str(current_time): current_time})
pipe.expire(key, ANOMALY_DUPLICATE_WINDOW + 10)
results = await pipe.execute()
duplicate_count = results[1]
if duplicate_count >= ANOMALY_DUPLICATE_THRESHOLD:
masked_key = f"{api_key[:8]}..." if len(api_key) > 8 else "***"
print(f"⚠️ ANOMALY DETECTED: API key {masked_key} sent identical prompt {duplicate_count} times in {ANOMALY_DUPLICATE_WINDOW}s")
print(f" Prompt hash: {prompt_hash[:16]}...")
    except redis.RedisError:
        pass  # anomaly detection is best-effort; ignore Redis errors
@app.get("/")
async def root():
"""헬스 체크 엔드포인트"""
return {
"status": "healthy",
"service": "vLLM Proxy API",
"version": "1.0.0",
"vllm_url": VLLM_URL
}
@app.get("/health")
async def health_check():
"""상세 헬스 체크 - vLLM 서버 연결 확인"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"{VLLM_URL}/health")
vllm_healthy = response.status_code == 200
    except Exception:
        vllm_healthy = False
return {
"proxy": "healthy",
"vllm_server": "healthy" if vllm_healthy else "unhealthy",
"timestamp": datetime.now().isoformat()
}
@app.post("/v1/completions")
async def completions(
req: Request,
request: CompletionRequest,
rate_limit_info: Dict[str, Any] = Depends(check_rate_limit)
):
"""
텍스트 완성 엔드포인트
- vLLM의 /v1/completions 엔드포인트로 프록시
- temperature, n (n_samples) 파라미터 지원
- 토큰 수 검증 및 이상 패턴 탐지
"""
api_key = rate_limit_info["api_key"]
req.state.rate_limit_info = rate_limit_info
evaluation_mode = is_evaluation_mode(api_key)
if evaluation_mode:
print(f"📋 [Evaluation Mode: True] Request from evaluation API key")
check_safety_filter(request.prompt, api_key)
validate_prompt_tokens(request.prompt)
prompt_hash = hashlib.sha256(request.prompt.encode()).hexdigest()
await detect_anomaly_pattern(api_key, prompt_hash)
try:
async with httpx.AsyncClient(timeout=300.0) as client:
response = await client.post(
f"{VLLM_URL}/v1/completions",
json=request.model_dump(),
headers={"Content-Type": "application/json"}
)
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=f"vLLM server error: {response.text}"
)
return response.json()
except httpx.TimeoutException:
raise HTTPException(status_code=504, detail="Request to vLLM server timed out")
except httpx.RequestError as e:
raise HTTPException(status_code=502, detail=f"Error connecting to vLLM server: {str(e)}")
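# Illustrative client call (a sketch; "sk-example" is a placeholder key and
# the port assumes the default PORT=8001):
#   curl -s http://localhost:8001/v1/completions \
#     -H "Authorization: Bearer sk-example" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "Qwen/Qwen3-4B-Instruct-2507", "prompt": "Hi", "n": 2}'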
@app.post("/v1/chat/completions")
async def chat_completions(
req: Request,
request: ChatCompletionRequest,
rate_limit_info: Dict[str, Any] = Depends(check_rate_limit)
):
"""
채팅 완성 엔드포인트
- vLLM의 /v1/chat/completions 엔드포인트로 프록시
- temperature, n (n_samples) 파라미터 지원
- 토큰 수 검증 및 이상 패턴 탐지
"""
api_key = rate_limit_info["api_key"]
req.state.rate_limit_info = rate_limit_info
evaluation_mode = is_evaluation_mode(api_key)
if evaluation_mode:
print(f"📋 [Evaluation Mode: True] Request from evaluation API key")
try:
async with httpx.AsyncClient(timeout=300.0) as client:
request_dict = request.model_dump()
normalized_messages = []
combined_text = []
for msg in request.messages:
msg_dict = msg.model_dump()
if isinstance(msg_dict["content"], list):
text_parts = []
for part in msg_dict["content"]:
if isinstance(part, dict) and part.get("type") == "text":
text_parts.append(part.get("text", ""))
msg_dict["content"] = "\n".join(text_parts)
if msg_dict.get("name") is None:
msg_dict.pop("name", None)
normalized_messages.append(msg_dict)
combined_text.append(msg_dict["content"])
request_dict["messages"] = normalized_messages
full_text = "\n".join(combined_text)
check_safety_filter(full_text, api_key)
validate_prompt_tokens(full_text)
prompt_hash = hashlib.sha256(full_text.encode()).hexdigest()
await detect_anomaly_pattern(api_key, prompt_hash)
print(f"📤 Sending to vLLM:")
print(f" Messages: {normalized_messages}")
response = await client.post(
f"{VLLM_URL}/v1/chat/completions",
json=request_dict,
headers={"Content-Type": "application/json"}
)
if response.status_code != 200:
print(f"❌ vLLM Error ({response.status_code}): {response.text}")
raise HTTPException(
status_code=response.status_code,
detail=f"vLLM server error: {response.text}"
)
return response.json()
except httpx.TimeoutException:
raise HTTPException(status_code=504, detail="Request to vLLM server timed out")
except httpx.RequestError as e:
raise HTTPException(status_code=502, detail=f"Error connecting to vLLM server: {str(e)}")
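# Normalization example: a multimodal message such as
#   {"role": "user", "content": [{"type": "text", "text": "Hi"},
#                                {"type": "image_url", "image_url": {"url": "..."}}]}
# is forwarded to vLLM as {"role": "user", "content": "Hi"}: image parts are
# dropped and only the text parts are kept, joined with newlines.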
@app.get("/v1/models")
async def list_models(
req: Request,
rate_limit_info: Dict[str, Any] = Depends(check_rate_limit)
):
"""
사용 가능한 모델 목록 조회
- vLLM의 /v1/models 엔드포인트로 프록시
"""
req.state.rate_limit_info = rate_limit_info
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(
f"{VLLM_URL}/v1/models",
headers={"Content-Type": "application/json"}
)
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=f"vLLM server error: {response.text}"
)
return response.json()
except httpx.TimeoutException:
raise HTTPException(status_code=504, detail="Request to vLLM server timed out")
except httpx.RequestError as e:
raise HTTPException(status_code=502, detail=f"Error connecting to vLLM server: {str(e)}")
@app.get("/admin/rate-limits")
async def get_rate_limits(api_key: str = Depends(verify_api_key)):
"""
현재 Rate Limit 상태 조회 (관리자용, Redis 기반)
- 각 API 키별 현재 요청 수 확인
"""
if not redis_client:
return {"error": "Redis not connected", "fallback": "Rate limiting disabled"}
try:
current_time = time.time()
window_start = current_time - RATE_LIMIT_WINDOW
        stats = {}
        async for key in redis_client.scan_iter(match="rate_limit:*"):
            api_key_value = key.replace("rate_limit:", "", 1)
            count = await redis_client.zcount(key, window_start, current_time)
masked_key = f"{api_key_value[:8]}..." if len(api_key_value) > 8 else "***"
stats[masked_key] = {
"current_requests": count,
"limit": RATE_LIMIT_REQUESTS,
"window_seconds": RATE_LIMIT_WINDOW
}
return stats
except redis.RedisError as e:
raise HTTPException(status_code=500, detail=f"Redis error: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"main:app",
host=os.getenv("HOST", "127.0.0.1"),
port=int(os.getenv("PORT", "8001")),
reload=os.getenv("RELOAD", "False").lower() == "true"
)