fastapi serving

Dorong·2026년 1월 12일

AI

목록 보기
2/5
"""
FastAPI Wrapper for vLLM Server
- API Key Authentication (Bearer Token)
- Rate Limiting
- CORS Support
- Temperature & n_samples support
"""

import os
import time
import hashlib
from typing import Optional, List, Dict, Any, Literal, Union
from datetime import datetime, timedelta
from contextlib import asynccontextmanager

import httpx
import redis.asyncio as redis
from fastapi import FastAPI, HTTPException, Security, Request, Depends, Response
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field
from dotenv import load_dotenv
from transformers import AutoTokenizer

# 환경 변수 로드
load_dotenv()

# 설정
VLLM_URL = os.getenv("VLLM_URL", "http://localhost:8000")
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
ALLOWED_API_KEYS = os.getenv("ALLOWED_API_KEYS", "").split(",")
ALLOWED_API_KEYS = [key.strip() for key in ALLOWED_API_KEYS if key.strip()]  # 쉼표 구분 API 키 공백 제거 및 빈 값 필터링

# 평가 모드 API 키 설정
EVALUATION_API_KEYS = os.getenv("EVALUATION_API_KEYS", "").split(",")
EVALUATION_API_KEYS = [key.strip() for key in EVALUATION_API_KEYS if key.strip()]

# 안전 필터 설정
SAFETY_FILTER_ENABLED = os.getenv("SAFETY_FILTER_ENABLED", "True").lower() == "true"
BLOCKED_KEYWORDS = [
    # 위험 물질
    "폭탄", "폭발물", "화약", "독극물", "청산가리",
    # 불법 약물
    "마약", "필로폰", "메스암페타민", "코카인", "헤로인",
    # 무기
    "총기", "권총", "소총",
    # 자해/타해
    "자살", "자해",
]

# Rate Limiting 설정
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60"))  # 분당 요청 수
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60"))  # 윈도우 시간 (초)

# CORS 설정
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "*").split(",")
CORS_ORIGINS = [origin.strip() for origin in CORS_ORIGINS]

# Token Limit 설정
MAX_PROMPT_TOKENS = int(os.getenv("MAX_PROMPT_TOKENS", "8000"))

# Anomaly Detection 설정
ANOMALY_DUPLICATE_THRESHOLD = int(os.getenv("ANOMALY_DUPLICATE_THRESHOLD", "10"))  # 5분 내 동일 요청 횟수
ANOMALY_DUPLICATE_WINDOW = int(os.getenv("ANOMALY_DUPLICATE_WINDOW", "300"))  # 5분 (초)

# 전역 변수
redis_client: Optional[redis.Redis] = None
tokenizer = None


# ============================================================
# Lifespan Event Handler
# ============================================================

@asynccontextmanager
async def lifespan(_app: FastAPI):
    """애플리케이션 생명주기 관리"""
    # Startup
    global redis_client, tokenizer

    # Tokenizer 로드
    try:
        print("📥 Loading Qwen tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3-4B-Instruct-2507",
            trust_remote_code=True
        )
        print(f"✅ Loaded Qwen tokenizer (max_prompt_tokens: {MAX_PROMPT_TOKENS})")
    except Exception as e:
        print(f"⚠️  Failed to load tokenizer: {e}")
        print("⚠️  Token validation will use fallback method")
        tokenizer = None

    # Redis 연결
    try:
        redis_client = await redis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
        await redis_client.ping()
        print(f"✅ Connected to Redis at {REDIS_URL}")
    except Exception as e:
        print(f"⚠️  Failed to connect to Redis: {e}")
        print("⚠️  Falling back to in-memory rate limiting")
        redis_client = None

    yield

    # Shutdown
    if redis_client:
        await redis_client.close()
        print("✅ Disconnected from Redis")


# FastAPI 앱 초기화
app = FastAPI(
    title="vLLM Proxy API",
    description="FastAPI wrapper for vLLM with authentication, rate limiting, and CORS",
    version="1.0.1",
    lifespan=lifespan
)

# CORS 미들웨어 추가
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# HTTP Bearer 인증 스킴
security = HTTPBearer()


# ============================================================
# OpenAI Compatible Exception Handlers
# ============================================================

@app.exception_handler(HTTPException)
async def openai_compatible_exception_handler(request: Request, exc: HTTPException):
    """
    HTTPException을 OpenAI SDK 호환 형식으로 변환
    OpenAI SDK는 { "error": { "message": ..., "type": ..., "code": ... } } 형식을 기대함
    """
    if isinstance(exc.detail, dict):
        error_response = {
            "error": {
                "message": exc.detail.get("message", str(exc.detail)),
                "type": exc.detail.get("error", "api_error"),
                "code": exc.detail.get("error", "api_error"),
                "param": None
            }
        }
    else:
        error_response = {
            "error": {
                "message": str(exc.detail),
                "type": "api_error",
                "code": None,
                "param": None
            }
        }

    # status code에 따라 type 설정
    if exc.status_code == 429:
        error_response["error"]["type"] = "rate_limit_error"
    elif exc.status_code == 401:
        error_response["error"]["type"] = "authentication_error"
    elif exc.status_code == 400:
        error_response["error"]["type"] = "invalid_request_error"

    return JSONResponse(
        status_code=exc.status_code,
        content=error_response,
        headers=exc.headers
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Pydantic validation 에러를 OpenAI SDK 호환 형식으로 변환"""
    return JSONResponse(
        status_code=400,
        content={
            "error": {
                "message": str(exc.errors()),
                "type": "invalid_request_error",
                "code": "validation_error",
                "param": None
            }
        }
    )


# ============================================================
# Request Middleware
# ============================================================

@app.middleware("http")
async def log_requests(request: Request, call_next):
    """요청 body 로깅 (디버깅용)"""
    if request.method == "POST":
        try:
            body = await request.body()
            print(f"📥 Request to {request.url.path}")
            print(f"   Body: {body.decode('utf-8')[:500]}")  # 처음 500자만

            # body를 다시 읽을 수 있도록 재설정
            async def receive():
                return {"type": "http.request", "body": body}
            request._receive = receive
        except Exception as e:
            print(f"⚠️  Failed to log request body: {e}")

    response = await call_next(request)
    return response


@app.middleware("http")
async def add_rate_limit_headers(request: Request, call_next):
    """Rate Limit 헤더를 모든 응답에 추가"""
    response = await call_next(request)

    # rate_limit_info가 request.state에 저장된 경우 헤더 추가
    if hasattr(request.state, "rate_limit_info"):
        info = request.state.rate_limit_info
        response.headers["X-RateLimit-Limit"] = str(RATE_LIMIT_REQUESTS)
        response.headers["X-RateLimit-Remaining"] = str(info.get("remaining", 0))
        response.headers["X-RateLimit-Reset"] = str(int(info.get("reset_time", 0)))

    return response


# ============================================================
# Pydantic Models
# ============================================================

# Pydantic 클래스, 괄호 한은 상속받을 부모 클래스
# Pydantic 의 BaseModel 클래스는 크래스 변수 선언이 타입과 validation 스키마로 동작하게 함.
class CompletionRequest(BaseModel):
    """vLLM Completion 요청 모델"""
    model: str # 필수 필드
    prompt: str
    max_tokens: Optional[int] = Field(default=100, ge=1, le=4096) # 옵셔널 필드 + 런타임에서 최소 최댓값 검증 적용됨
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    n: Optional[int] = Field(default=1, ge=1, le=10, description="Number of samples to generate (n_samples)")
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    logprobs: Optional[int] = Field(default=None, ge=0, le=20, description="Number of log probabilities to return per token")
    echo: Optional[bool] = Field(default=False, description="Echo back the prompt in addition to the completion")


class ChatCompletionContentPartText(BaseModel):
    """텍스트 content part"""
    type: Literal["text"]
    text: str


class ChatCompletionContentPartImage(BaseModel):
    """이미지 content part"""
    type: Literal["image_url"]
    image_url: Dict[str, str]  # {"url": "..."}


ChatCompletionContentPart = Union[ChatCompletionContentPartText, ChatCompletionContentPartImage]


class ChatMessage(BaseModel):
    """채팅 메시지 모델 (OpenAI compatible)"""
    role: Literal["system", "user", "assistant", "tool", "function", "developer"]
    content: Union[str, List[ChatCompletionContentPart]]
    name: Optional[str] = None  # 참가자 이름 (선택)


class ChatCompletionRequest(BaseModel):
    """vLLM Chat Completion 요청 모델"""
    model: str
    messages: List[ChatMessage]
    max_tokens: Optional[int] = Field(default=100, ge=1, le=4096)
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    n: Optional[int] = Field(default=1, ge=1, le=10, description="Number of samples to generate (n_samples)")
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    logprobs: Optional[bool] = Field(default=False, description="Whether to return log probabilities")
    top_logprobs: Optional[int] = Field(default=None, ge=0, le=20, description="Number of top log probabilities per token (requires logprobs=True)")

    # Structured Output 지원 (OpenAI compatible)
    response_format: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Response format for structured output. Example: {'type': 'json_object'} or {'type': 'json_schema', 'json_schema': {...}}"
    )

    # vLLM Guided Decoding (더 강력한 구조화 출력)
    guided_json: Optional[Dict[str, Any]] = Field(
        default=None,
        description="JSON schema for guided generation (vLLM specific)"
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description="Regex pattern for guided generation (vLLM specific)"
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description="List of choices for guided generation (vLLM specific)"
    )

    class Config:
        extra = "allow"  # 추가 필드 허용 (OpenAI SDK 호환성)


# ============================================================
# Authentication & Rate Limiting
# ============================================================

# Security(security) => fastAPI의 Dependency Injection
# Depends와 다른 점은 OAuth2 스코프(권한) 선언 기능이 추가된 보안 의존성 선언 함수라는 것.
# 현재 함수에 Security에는 scope 배열을 포함하고 있지는 않지만, 컨벤션 및 확장성, 문서화 등의 면에서 Security 사용
def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
    """
    API 키 검증
    - Bearer 토큰 형식으로 전달된 API 키를 검증
    - ALLOWED_API_KEYS 또는 EVALUATION_API_KEYS 목록에 있는 키만 허용
    """
    token = credentials.credentials

    if not ALLOWED_API_KEYS and not EVALUATION_API_KEYS:
        raise HTTPException(
            status_code=500,
            detail="Server configuration error: No API keys configured"
        )

    # 일반 키 또는 평가용 키 모두 허용
    all_valid_keys = ALLOWED_API_KEYS + EVALUATION_API_KEYS
    if token not in all_valid_keys:
        raise HTTPException(
            status_code=401,
            detail="Invalid API key"
        )

    return token


async def check_rate_limit(api_key: str = Depends(verify_api_key)) -> Dict[str, Any]:
    """
    Rate Limiting 체크 (Redis 기반)
    - API 키별로 요청 횟수 제한
    - 슬라이딩 윈도우 방식 사용
    - Redis Sorted Set 활용 (score = timestamp)
    - 반환: {"api_key": str, "remaining": int, "reset_time": float}
    """
    if not redis_client:
        # Redis 연결 실패 시 rate limiting 비활성화
        return {
            "api_key": api_key,
            "remaining": RATE_LIMIT_REQUESTS,
            "reset_time": time.time() + RATE_LIMIT_WINDOW
        }

    current_time = time.time()
    key = f"rate_limit:{api_key}"

    try:
        # Redis Pipeline으로 여러 명령을 원자적으로 실행
        pipe = redis_client.pipeline()

        # 1. 오래된 요청 제거 (윈도우 밖)
        window_start = current_time - RATE_LIMIT_WINDOW
        pipe.zremrangebyscore(key, 0, window_start)

        # 2. 현재 윈도우 내 요청 수 카운트
        pipe.zcard(key)

        # 3. 현재 요청 시간 추가
        pipe.zadd(key, {str(current_time): current_time})

        # 4. 키 만료 시간 설정 (윈도우 시간 + 여유)
        pipe.expire(key, RATE_LIMIT_WINDOW + 10)

        # 실행
        results = await pipe.execute()
        request_count = results[1]  # zcard 결과

        # 남은 요청 수 계산
        remaining = max(0, RATE_LIMIT_REQUESTS - request_count)

        # 리셋 시간 계산 (현재 윈도우가 끝나는 시간)
        reset_time = current_time + RATE_LIMIT_WINDOW

        # Rate limit 체크
        if request_count >= RATE_LIMIT_REQUESTS:
            # 얼마나 기다려야 하는지 계산
            oldest_request = await redis_client.zrange(key, 0, 0, withscores=True)
            if oldest_request:
                oldest_timestamp = oldest_request[0][1]
                retry_after = int(oldest_timestamp + RATE_LIMIT_WINDOW - current_time)
            else:
                retry_after = RATE_LIMIT_WINDOW

            raise HTTPException(
                status_code=429,
                detail={
                    "error": "rate_limit_exceeded",
                    "message": f"Rate limit exceeded. Maximum {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW} seconds.",
                    "retry_after": retry_after,
                    "limit": RATE_LIMIT_REQUESTS,
                    "window": RATE_LIMIT_WINDOW
                },
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(RATE_LIMIT_REQUESTS),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(int(reset_time))
                }
            )

        return {
            "api_key": api_key,
            "remaining": remaining,
            "reset_time": reset_time
        }

    except redis.RedisError as e:
        # Redis 오류 시 로깅하고 통과 (서비스 중단 방지)
        print(f"⚠️  Redis error in rate limiting: {e}")
        return {
            "api_key": api_key,
            "remaining": RATE_LIMIT_REQUESTS,
            "reset_time": time.time() + RATE_LIMIT_WINDOW
        }


# ============================================================
# Evaluation Mode & Safety Filter
# ============================================================

def is_evaluation_mode(api_key: str) -> bool:
    """API 키가 평가 모드인지 확인"""
    return api_key in EVALUATION_API_KEYS


def check_safety_filter(text: str, api_key: str) -> None:
    """
    안전 필터 검사
    - 차단 키워드가 포함된 요청 거부
    - 평가 모드에서는 필터 우회 (로그만 기록)
    """
    if not SAFETY_FILTER_ENABLED:
        return

    evaluation_mode = is_evaluation_mode(api_key)

    for keyword in BLOCKED_KEYWORDS:
        if keyword in text:
            if evaluation_mode:
                print(f"⚠️  [Evaluation Mode: True] Safety filter bypassed for keyword: {keyword}")
                return  # 평가 모드에서는 통과
            else:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "content_policy_violation",
                        "message": "Your request contains content that violates our usage policy.",
                    }
                )


# ============================================================
# Token Validation & Anomaly Detection
# ============================================================

def count_tokens(text: str) -> int:
    """
    텍스트의 토큰 수 계산 (Qwen 기준)
    - tokenizer 사용 가능 시: 정확한 토큰 수
    - fallback: 대략적 계산 (한글 기준 1글자 ≈ 2토큰)
    """
    if tokenizer:
        try:
            return len(tokenizer.encode(text))
        except Exception as e:
            print(f"⚠️  Tokenizer error: {e}, using fallback")

    # Fallback: 대략적 계산
    return len(text) // 2


def validate_prompt_tokens(prompt: str) -> None:
    """
    Prompt 토큰 수 검증
    - MAX_PROMPT_TOKENS 초과 시 HTTPException 발생
    """
    token_count = count_tokens(prompt)
    print(f'=======Token Count: {token_count}')

    if token_count > MAX_PROMPT_TOKENS:
        raise HTTPException(
            status_code=422,
            detail={
                "error": "validation_error",
                "message": f"Prompt too long. Please reduce to {MAX_PROMPT_TOKENS} tokens or less.",
                "details": {
                    "current_tokens": token_count,
                    "max_tokens": MAX_PROMPT_TOKENS,
                    "suggestion": "Try shortening your prompt or splitting it into multiple requests"
                }
            }
        )


async def detect_anomaly_pattern(api_key: str, prompt_hash: str) -> None:
    """
    동일 프롬프트 중복 요청 탐지 (로깅 전용)
    - 동일한 prompt_hash가 윈도우 내에서 반복되는지 감지
    - Redis 기반, 로그만 기록
    """
    if not redis_client:
        return

    try:
        current_time = time.time()
        # prompt_hash별로 별도 키 생성 → 동일 프롬프트만 카운팅
        key = f"anomaly:duplicate:{api_key}:{prompt_hash[:16]}"

        pipe = redis_client.pipeline()

        # 1. 오래된 요청 제거 (윈도우 밖)
        window_start = current_time - ANOMALY_DUPLICATE_WINDOW
        pipe.zremrangebyscore(key, 0, window_start)

        # 2. 현재 키의 요청 수 카운트 (= 동일 프롬프트 반복 횟수)
        pipe.zcard(key)

        # 3. 현재 요청 추가
        pipe.zadd(key, {str(current_time): current_time})

        # 4. 키 만료 시간 설정
        pipe.expire(key, ANOMALY_DUPLICATE_WINDOW + 10)

        results = await pipe.execute()
        duplicate_count = results[1]  # zcard 결과

        # 임계값 초과 시 경고 로그
        if duplicate_count >= ANOMALY_DUPLICATE_THRESHOLD:
            masked_key = f"{api_key[:8]}..." if len(api_key) > 8 else "***"
            print(f"⚠️  ANOMALY DETECTED: API key {masked_key} sent identical prompt {duplicate_count} times in {ANOMALY_DUPLICATE_WINDOW}s")
            print(f"   Prompt hash: {prompt_hash[:16]}...")

    except redis.RedisError as e:
        # Redis 오류 시 조용히 무시 (모니터링은 선택적 기능)
        pass


# ============================================================
# Health Check
# ============================================================

@app.get("/")
async def root():
    """헬스 체크 엔드포인트"""
    return {
        "status": "healthy",
        "service": "vLLM Proxy API",
        "version": "1.0.0",
        "vllm_url": VLLM_URL
    }


@app.get("/health")
async def health_check():
    """상세 헬스 체크 - vLLM 서버 연결 확인"""
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{VLLM_URL}/health")
            vllm_healthy = response.status_code == 200
    except Exception as e:
        vllm_healthy = False

    return {
        "proxy": "healthy",
        "vllm_server": "healthy" if vllm_healthy else "unhealthy",
        "timestamp": datetime.now().isoformat()
    }


# ============================================================
# Proxy Endpoints
# ============================================================

@app.post("/v1/completions")
async def completions(
    req: Request,
    request: CompletionRequest,
    rate_limit_info: Dict[str, Any] = Depends(check_rate_limit)
):
    """
    텍스트 완성 엔드포인트
    - vLLM의 /v1/completions 엔드포인트로 프록시
    - temperature, n (n_samples) 파라미터 지원
    - 토큰 수 검증 및 이상 패턴 탐지
    """
    api_key = rate_limit_info["api_key"]
    req.state.rate_limit_info = rate_limit_info
    evaluation_mode = is_evaluation_mode(api_key)

    if evaluation_mode:
        print(f"📋 [Evaluation Mode: True] Request from evaluation API key")

    # 1. 안전 필터 검사
    check_safety_filter(request.prompt, api_key)

    # 2. Prompt 토큰 수 검증
    validate_prompt_tokens(request.prompt)

    # 3. 이상 패턴 탐지 (비동기, 로깅만)
    prompt_hash = hashlib.sha256(request.prompt.encode()).hexdigest()
    await detect_anomaly_pattern(api_key, prompt_hash)

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            response = await client.post(
                f"{VLLM_URL}/v1/completions",
                json=request.model_dump(),
                headers={"Content-Type": "application/json"}
            )

            if response.status_code != 200:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"vLLM server error: {response.text}"
                )

            return response.json()

    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Request to vLLM server timed out")
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Error connecting to vLLM server: {str(e)}")


@app.post("/v1/chat/completions")
async def chat_completions(
    req: Request,
    request: ChatCompletionRequest,
    rate_limit_info: Dict[str, Any] = Depends(check_rate_limit)
):
    """
    채팅 완성 엔드포인트
    - vLLM의 /v1/chat/completions 엔드포인트로 프록시
    - temperature, n (n_samples) 파라미터 지원
    - 토큰 수 검증 및 이상 패턴 탐지
    """
    api_key = rate_limit_info["api_key"]
    req.state.rate_limit_info = rate_limit_info
    evaluation_mode = is_evaluation_mode(api_key)

    if evaluation_mode:
        print(f"📋 [Evaluation Mode: True] Request from evaluation API key")

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            # messages를 dict로 변환하고, content를 string으로 정규화
            request_dict = request.model_dump()
            normalized_messages = []
            combined_text = []  # 토큰 수 계산용

            for msg in request.messages:
                msg_dict = msg.model_dump()

                # ChatCompletionMessageParam 타입 기준 파싱
                # content가 리스트인 경우 text만 추출하여 string으로 변환
                if isinstance(msg_dict["content"], list):
                    text_parts = []
                    for part in msg_dict["content"]:
                        if isinstance(part, dict) and part.get("type") == "text":
                            text_parts.append(part.get("text", ""))
                    msg_dict["content"] = "\n".join(text_parts)

                # name이 None이면 필드 제거 (vLLM이 None을 거부함)
                if msg_dict.get("name") is None:
                    msg_dict.pop("name", None)

                normalized_messages.append(msg_dict)
                combined_text.append(msg_dict["content"])

            request_dict["messages"] = normalized_messages

            # 1. 전체 메시지 텍스트 추출
            full_text = "\n".join(combined_text)

            # 2. 안전 필터 검사
            check_safety_filter(full_text, api_key)

            # 3. 토큰 수 검증
            validate_prompt_tokens(full_text)

            # 4. 이상 패턴 탐지 (비동기, 로깅만)
            prompt_hash = hashlib.sha256(full_text.encode()).hexdigest()
            await detect_anomaly_pattern(api_key, prompt_hash)

            # vLLM에 전송할 데이터 로깅
            print(f"📤 Sending to vLLM:")
            print(f"   Messages: {normalized_messages}")

            response = await client.post(
                f"{VLLM_URL}/v1/chat/completions",
                json=request_dict,
                headers={"Content-Type": "application/json"}
            )

            if response.status_code != 200:
                print(f"❌ vLLM Error ({response.status_code}): {response.text}")
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"vLLM server error: {response.text}"
                )

            return response.json()

    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Request to vLLM server timed out")
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Error connecting to vLLM server: {str(e)}")


@app.get("/v1/models")
async def list_models(
    req: Request,
    rate_limit_info: Dict[str, Any] = Depends(check_rate_limit)
):
    """
    사용 가능한 모델 목록 조회
    - vLLM의 /v1/models 엔드포인트로 프록시
    """
    req.state.rate_limit_info = rate_limit_info

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(
                f"{VLLM_URL}/v1/models",
                headers={"Content-Type": "application/json"}
            )

            if response.status_code != 200:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"vLLM server error: {response.text}"
                )

            return response.json()

    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Request to vLLM server timed out")
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Error connecting to vLLM server: {str(e)}")


# ============================================================
# Admin Endpoints (Optional - for monitoring)
# ============================================================

@app.get("/admin/rate-limits")
async def get_rate_limits(api_key: str = Depends(verify_api_key)):
    """
    현재 Rate Limit 상태 조회 (관리자용, Redis 기반)
    - 각 API 키별 현재 요청 수 확인
    """
    if not redis_client:
        return {"error": "Redis not connected", "fallback": "Rate limiting disabled"}

    try:
        current_time = time.time()
        window_start = current_time - RATE_LIMIT_WINDOW

        # Redis에서 rate_limit:* 패턴의 모든 키 조회
        keys = await redis_client.keys("rate_limit:*")

        stats = {}
        for key in keys:
            # rate_limit:api_key 형태에서 api_key 추출
            api_key_value = key.replace("rate_limit:", "")

            # 현재 윈도우 내 요청 수 카운트
            count = await redis_client.zcount(key, window_start, current_time)

            # API 키 마스킹 (보안)
            masked_key = f"{api_key_value[:8]}..." if len(api_key_value) > 8 else "***"

            stats[masked_key] = {
                "current_requests": count,
                "limit": RATE_LIMIT_REQUESTS,
                "window_seconds": RATE_LIMIT_WINDOW
            }

        return stats

    except redis.RedisError as e:
        raise HTTPException(status_code=500, detail=f"Redis error: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    # 서버 실행
    # 외부 접근을 허용하려면 host를 "0.0.0.0"으로 설정
    uvicorn.run(
        "main:app",
        host=os.getenv("HOST", "127.0.0.1"),  # 기본값은 로컬만, 외부 접근은 0.0.0.0
        port=int(os.getenv("PORT", "8001")),   # vLLM(8000)과 다른 포트 사용
        reload=os.getenv("RELOAD", "False").lower() == "true"
    )
profile
AI R&D와 웹/앱개발 욕심쟁이 멀티 플레이🐖

0개의 댓글