faiss 해결

JinYoungMo·2024년 4월 18일
import numpy as np
import faiss
import torch
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer, models
from sklearn.metrics.pairwise import cosine_similarity
from konlpy.tag import Okt
import re
import os
import psycopg2

# Module-level holders for the Faiss index and the sentence-embedding model.
# Both start as None and are assigned by initialize_app().
index = None
sentence_model = None

# Database connection settings. The DATABASE_URL environment variable takes
# precedence so credentials do not have to live in the source file; the
# original local-development DSN remains the fallback.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/postgres",
)

def get_db_connection():
    """Open and return a new psycopg2 connection to ``DATABASE_URL``.

    The connection is not closed here; that is left to the caller.
    """
    connection = psycopg2.connect(DATABASE_URL)
    return connection

# Flask application object that the route below is registered on.
app = Flask(__name__)

# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Shared Okt tagger instance (from konlpy).
okt = Okt()

# Stopword list used when cleaning text.
stopwords = [
    '의', '가', '이', '은', '들', '는', '좀', '잘', '걍',
    '과', '도', '를', '으로', '자', '에', '와', '한', '하다',
]

def clean_text_with_okt(text):
    """Normalize *text* and keep only selected parts of speech.

    Square-bracketed spans are removed and whitespace runs are collapsed to a
    single space, then the text is POS-tagged with ``okt``. Only tokens tagged
    ``Noun``, ``Verb`` or ``Adjective`` that are not in ``stopwords`` are kept.

    Returns the kept tokens joined by single spaces, or ``''`` when *text* is
    ``None``.
    """
    if text is None:
        return ''

    kept_tags = ('Noun', 'Verb', 'Adjective')

    # Trim, drop "[...]" spans, then squeeze whitespace runs.
    normalized = re.sub(r'\[.*?\]', '', text.strip())
    normalized = re.sub(r'\s+', ' ', normalized)

    kept_tokens = []
    for token, tag in okt.pos(normalized):
        if tag not in kept_tags:
            continue
        if token in stopwords:
            continue
        kept_tokens.append(token)

    return ' '.join(kept_tokens)

def _build_sentence_model():
    """Build the mean-pooled ``snunlp/KR-Medium`` SentenceTransformer on ``device``."""
    word_embedding_model = models.Transformer("snunlp/KR-Medium")
    pooling_model = models.Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
    )
    return SentenceTransformer(modules=[word_embedding_model, pooling_model]).to(device)


def initialize_app():
    """Populate the module-level ``index`` and ``sentence_model``.

    Loads the precomputed embeddings from ``files/modelembeddings.npy`` into a
    flat L2 Faiss index, then builds the sentence model. If
    ``sentence_model.pt`` exists its state dict is loaded into the model;
    otherwise the freshly built model's state dict is saved to that path.

    Raises:
        ValueError: if the stored embeddings are not a 2-D array.
    """
    global index, sentence_model

    # Faiss expects C-contiguous float32 input.
    sentence_embeddings = np.ascontiguousarray(
        np.load('files/modelembeddings.npy'), dtype='float32'
    )
    if sentence_embeddings.ndim != 2:
        raise ValueError(
            "expected 2-D embeddings of shape (n_vectors, dimension), "
            f"got shape {sentence_embeddings.shape}"
        )

    # Size the index from the data rather than a hard-coded 768, so a changed
    # embedding file cannot silently disagree with the index dimension.
    dimension = sentence_embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(sentence_embeddings)

    model_path = 'sentence_model.pt'
    model = _build_sentence_model()
    if os.path.exists(model_path):
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that this application wrote itself.
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        torch.save(model.state_dict(), model_path)

    # Put the model in inference mode on both paths (the original only did
    # so when no checkpoint existed).
    model.eval()
    sentence_model = model

@app.route('/find_cases', methods=['POST'])
def find_cases():
    """Return the stored cases closest to the posted ``user_input``.

    Expects a JSON object body with a non-empty string ``user_input``. The
    input is cleaned, embedded with ``sentence_model`` and searched against
    the Faiss ``index``; the full text of each hit is read from the
    ``caseoflaw`` table.

    Returns:
        A JSON list of ``{"index", "similarity_score", "full_text"}`` objects.
        ``similarity_score`` is the distance value returned by the index
        search (the index is built as ``IndexFlatL2``, so smaller means
        closer). Responds 400 when ``user_input`` is missing or not a
        non-empty string, and 503 when the index/model are not initialized.
    """
    if index is None or sentence_model is None:
        return jsonify({"error": "search index is not initialized"}), 503

    # silent=True yields None instead of raising on a non-JSON body.
    payload = request.get_json(silent=True)
    user_input = payload.get('user_input', '') if isinstance(payload, dict) else None
    if not isinstance(user_input, str) or not user_input.strip():
        return jsonify({"error": "'user_input' must be a non-empty string"}), 400

    cleaned_input = clean_text_with_okt(user_input)
    app.logger.debug("cleaned input: %s", cleaned_input)

    # Convert the user input into a float32 embedding matrix for Faiss.
    query = np.ascontiguousarray(
        sentence_model.encode([cleaned_input], convert_to_tensor=False),
        dtype='float32',
    )

    # Look up the k nearest stored embeddings.
    k = 5
    distances, neighbours = index.search(query, k)

    results = []
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            for distance, idx in zip(distances[0], neighbours[0]):
                # Faiss pads with -1 when fewer than k vectors are available.
                if idx < 0:
                    continue
                cur.execute("SELECT 전문 FROM caseoflaw WHERE index = %s", (int(idx),))
                row = cur.fetchone()
                if row is None:
                    # The embedding has no matching row; skip it rather than
                    # failing the whole request.
                    app.logger.warning("no caseoflaw row for index %s", int(idx))
                    continue
                results.append({
                    "index": int(idx),
                    "similarity_score": float(distance),
                    "full_text": row[0],
                })
    finally:
        conn.close()

    return jsonify(results)


if __name__ == '__main__':
    # Debug mode enables the interactive debugger, which must not be reachable
    # by untrusted clients. It is therefore opt-in via FLASK_DEBUG instead of
    # being hard-coded on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in (
        "1", "true", "yes", "on",
    )
    initialize_app()
    app.run(debug=debug_enabled, port=5001)
profile
blockchain core & payments and stable coins

0개의 댓글