TensorFlow 모델을 API로 만들어 웹과 연동하기

1rock·2025년 2월 18일

🚀 TensorFlow 모델을 API로 만들어 웹과 연동하기

TensorFlow를 이용해 간단한 예측 모델을 만들고, Flask를 사용하여 API 서버를 구축한 후 웹에서 호출하는 방법을 정리했습니다.


📌 전체 흐름

  1. TensorFlow 모델 만들기 (간단한 선형 회귀)
  2. 모델을 학습하고 저장 (model.h5)
  3. Flask로 API 서버 만들기
  4. 웹에서 API 호출하여 예측값 받아오기

1️⃣ TensorFlow 모델 만들기

먼저, y = 2x 형태의 데이터를 학습하는 간단한 모델을 만들고 저장합니다.

🛠️ TensorFlow 설치

pip install tensorflow

📝 모델 코드 (train.py)

import tensorflow as tf
import numpy as np

# 1. 데이터 준비
x_train = np.array([1, 2, 3, 4, 5], dtype=np.float32)
y_train = np.array([2, 4, 6, 8, 10], dtype=np.float32)  # y = 2x

# 2. 모델 생성
model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=1, input_shape=[1])  # 입력 1개, 출력 1개
])

# 3. 모델 컴파일
model.compile(optimizer='sgd', loss='mean_squared_error')

# 4. 모델 학습
model.fit(x_train, y_train, epochs=500, verbose=0)

# 5. 모델 저장
model.save("model.h5")
print("✅ 모델 저장 완료!")

💡 설명

  • x_train, y_train 데이터를 사용해 y = 2x 패턴을 학습
  • model.h5 파일로 모델을 저장 (Flask API에서 불러와 사용하기 위함)

2️⃣ Flask API 서버 만들기

이제 Flask를 이용하여 API 서버를 구축합니다.

🛠️ Flask 설치

pip install flask tensorflow

📝 Flask 서버 코드 (app.py)

from flask import Flask, request, jsonify
import tensorflow as tf
import numpy as np

# 1. Flask 앱 생성
app = Flask(__name__)

# 2. 저장된 모델 불러오기
model = tf.keras.models.load_model("model.h5")

# 3. API 엔드포인트 설정
@app.route('/predict', methods=['POST'])
def predict():
    data = request.json  # JSON 데이터 받기
    x_value = float(data["x"])  # 입력 값 가져오기

    # 모델 예측
    prediction = model.predict(np.array([[x_value]]))[0][0]

    return jsonify({"x": x_value, "predicted_y": prediction})

# 4. 서버 실행
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

🏃 서버 실행

python app.py

💡 설명

  • model.h5 파일을 불러와 예측을 수행하는 API 생성
  • /predict 엔드포인트에서 JSON 데이터를 받아 예측 결과 반환

🔍 API 테스트 (Postman 또는 curl 사용)

curl -X POST http://localhost:5000/predict -H "Content-Type: application/json" -d '{"x": 6}'

✅ 응답 예시

{
  "x": 6,
  "predicted_y": 12.0
}

3️⃣ 웹에서 API 호출하기

이제 HTML과 JavaScript를 사용하여 웹에서 API를 호출해봅니다.

📝 index.html

<!DOCTYPE html>
<html lang="ko">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>TensorFlow API 예제</title>
</head>
<body>
    <h2>숫자를 입력하면 y 값을 예측</h2>
    <input type="number" id="inputX" placeholder="숫자 입력">
    <button onclick="predict()">예측하기</button>
    <p id="result"></p>

    <script>
        function predict() {
            let x = document.getElementById("inputX").value;
            
            fetch("http://localhost:5000/predict", {
                method: "POST",
                headers: { "Content-Type": "application/json" },
                body: JSON.stringify({ x: x })
            })
            .then(response => response.json())
            .then(data => {
                document.getElementById("result").innerText =
                    `입력값: ${data.x}, 예측값: ${data.predicted_y}`;
            })
            .catch(error => console.error("Error:", error));
        }
    </script>
</body>
</html>

💻 웹 실행 방법

  1. index.html을 브라우저에서 열기
  2. 숫자를 입력하고 "예측하기" 버튼을 누르면 Flask API 호출됨
  3. 예측 결과가 웹 화면에 출력됨

🎯 마무리

완성된 구조
1. TensorFlow 모델을 학습하고 저장 (model.h5)
2. Flask API 서버를 만들어 예측 기능 제공 (/predict)
3. 웹에서 API를 호출하여 예측 결과를 표시

profile
FRONT_END_DEVELOMENT

0개의 댓글