https://velog.io/@dldmswo1209/FlaskAndroidKotlin-Flask-REST-API-서버를-만들고-Android-Retrofit2-로-서버에서-데이터-가져오기
이전에 쓴 글에서 "Flask 로 AI 모델 연결 문제를 해결할 수 있을까?" 했는데!!! 해결했다!!
전에는 소켓통신을 사용해서 채팅방을 나갔다 오거나, 앱을 껐다 켜면, 서버를 껐다가 다시 켜야되는 1회성 문제가 있었는데, 지금은 서버만 켜져있으면 아무 문제 없이 잘 동작한다.
서버에도 데이터가 잘 전달되는 것을 볼 수 있다.
전체적인 설명을 하자면 파이썬 Flask 로 Rest API 서버를 만들고, GET 요청이 들어오면, GET 요청 파라미터로 텍스트(사용자의 채팅, 안드로이드에서 전달)를 받는다. 받은 텍스트를 토크나이저를 통해서 토큰화 한 후, AI 모델을 통해 맥락에 맞는 적절한 텍스트를 생성하고, 안드로이드에 전달한다.
from flask import Flask, jsonify
from flask_restful import Api
import torch
from model.kogpt2 import DialogKoGPT2
from kogpt2_transformers import get_kogpt2_tokenizer
#경로 설정하기
save_ckpt_path = f"./total_autoregressive_second_1.pth"
#모델 gpu 사용 여부
ctx = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(ctx)
# 저장한 모델 불러오기
checkpoint = torch.load(save_ckpt_path, map_location=device)
#모델 구조 불러오기
model = DialogKoGPT2()
#학습한 모델 불러오기
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# Flask 인스턴스 정리
app = Flask(__name__)
api = Api(app)
# 서버로 입력이 들어 오면 입력으로 들어온 텍스트를 토큰화 시키기 위한 토크나이저
tokenizer = get_kogpt2_tokenizer()
@app.route('/echo_call/<param>') #get echo api
def get_echo_call(param):
# param : 입력으로 들어오는 텍스트임
# 입력값을 토큰화
tokenized_indexs = tokenizer.encode(param)
# 모델에 input 으로 넣기 위한 shape 으로 변환
input_ids = torch.tensor([tokenizer.bos_token_id, ] + tokenized_indexs + [tokenizer.eos_token_id]).unsqueeze(0)
# 모델에 넣고, 예측 값을 sample_output 에 저장
sample_output = model.generate(input_ids=input_ids)
# 디코딩을 통해서 안드로이드에 전달할 텍스트로 변환
ans = tokenizer.decode(sample_output[0].tolist()[len(tokenized_indexs) + 1:], skip_special_tokens=True)
ans = ans[:ans.find(".")]
return jsonify({"param": ans}) # 모델의 예측 결과를 JSON 형태로 반환
# 서버를 실행 할 host 와 port 를 지정
if __name__ == '__main__':
app.run(host='ip 주소',port=8000,debug=True)
안드로이드에서 GET 요청을 하기위해서 다음 단계들을 수행 해야 한다.
import com.google.gson.JsonObject
import retrofit2.Call
import retrofit2.http.GET
import retrofit2.http.Path
interface RetrofitAPI {
@GET("/echo_call/{chatText}") // 서버에 GET 요청을 할 주소 입력
fun getAIReply(@Path("chatText") chatText: String) : Call<JsonObject> // 입력으로 chatText 를 서버로 넘기고, 서버에서 답장을 가져오는 메소드
}
GET 요청할 때 파라미터로 넘겨줄 텍스트를 알려줘야한다.
서버에 GET 요청 주소를 입력할 때 /{chatText} 를 작성하고
메소드 파라미터로 @Path("chatText") chatText: String 을 작성한다.
이렇게하면 파라미터가 자동으로 GET 요청 주소({chatText})로 넘어간다.
// 서버에서 가져온 정보를 JSON 형식의 데이터에서 내가 원하는 타입의 데이터로 변환시키기 위한 DTO
class RetrofitDTO {
data class ChatItem(val param: String)
}
lateinit var mRetrofit : Retrofit // 사용할 레트로핏 객체
lateinit var mRetrofitAPI: RetrofitAPI // 레트로핏 api 객체
lateinit var mCallAIReply : retrofit2.Call<JsonObject> // Json 형식의 데이터를 요청하는 객체
private fun setRetrofit() {
// retrofit 으로 가져올 url 을 설정하고 세팅
mRetrofit = Retrofit.Builder()
.baseUrl(getString(R.string.baseUrl))
.addConverterFactory(GsonConverterFactory.create())
.build()
// 인터페이스로 만든 레트로핏 api 요청 받는 것 변수로 등록
mRetrofitAPI = mRetrofit.create(RetrofitAPI::class.java)
}
private val mRetrofitCallback = (object : retrofit2.Callback<JsonObject>{
override fun onResponse(call: Call<JsonObject>, response: Response<JsonObject>) {
// 서버에서 데이터 요청 성공시
val result = response.body()
Log.d("testt", "결과는 ${result}")
var gson = Gson()
val dataParsed1 = gson.fromJson(result, RetrofitDTO.ChatItem::class.java)
val chatItem = ChatItem(dataParsed1.param, TYPE_BOT)
chatItemPushToDB(chatItem)
}
override fun onFailure(call: Call<JsonObject>, t: Throwable) {
// 서버 요청 실패
t.printStackTrace()
Log.d("testt", "에러입니다. ${t.message}")
ServerConnectErrorToast()
}
})
private fun callTodoList(chatText: String){
mCallAIReply = mRetrofitAPI.getAIReply(chatText) // RetrofitAPI 에서 JSON 객체를 요청해서 반환하는 메소드 호출
mCallAIReply.enqueue(mRetrofitCallback) // 응답을 큐에 넣어 대기 시켜놓음. 즉, 응답이 생기면 뱉어낸다.
}