LangChain X MongoDB Atlas

fortunetiger·2023년 8월 29일
0
post-thumbnail

LangChain을 처음 사용해 보면서 빠르게 진행했던 프로젝트이다 보니 해당 포스트의 방법이 최적의 방법이 아닐 수 있습니다. 그러나 프로젝트를 진행하면서 관련 자료가 많지 않았던 것 같아 뒤늦게라도 아카이빙을 겸해 작성합니다.

📂 Intro

웹으로 챗봇 데모 구현을 진행하면서, 다음과 같은 점을 고려해야 했습니다.

  1. 사용자 세션 유지 및 세션을 기반으로 대화 맥락을 반영할 것
  2. 멀티세션 대화가 가능할 것
  3. 웹서버상에서 동작하면서 개별 사용자의 요청을 처리할 수 있을 것

이를 위해 세션 정보와 메시지의 embedding을 DB에 저장하고, 입력된 메시지 내용을 바탕으로 과거 대화 중에서 vector search를 수행해 유사도 점수 상위 2개의 대화를 맥락으로 반영하기로 했습니다.

한편 LangChain에서는 기본적으로 MongoDB Atlas와의 연동을 위한 각종 클래스들을 제공하고 있습니다. 그럼에도 불구하고 제가 머리를 싸맸던 이유는, 기본 클래스를 사용해서 DB에 데이터를 넣고 읽을 때 대화 내용 외의 다른 정보를 사용하기 어려웠기 때문입니다.

그래서 다음과 같은 문제를 해결했고, 어떤 방법을 사용했는지 정리하고자 합니다.

  1. 대화 내용을 메모리(DB)에 저장할 때 메시지 텍스트와 embedding 외에 부가적인 정보(세션 정보, 메시지 생성 타임스탬프 등)를 함게 저장하도록 하기
  2. 유사도 검색 전 세션 정보로 메시지를 필터링하고 타임스탬프에 따른 가중치 부여하기

📂 MongoDB Atlas

MongoDB Atlas의 Vector Search Index를 사용해 Retrieval을 진행하기 위해 클러스터를 세팅해 봅시다.

MongoDB Atlas database 만들기

https://www.mongodb.com/atlas/database
무료 티어(Shared)에서도 Vector Search Index 사용이 가능합니다. 저는 실제 서비스가 아니라 데모용이므로, M0 Sandbox (General)로 생성해 사용했습니다. 클라우드 제공자와 리전은 원하시는 옵션으로 선택하시면 됩니다. 저는 Google Cloud, 리전은 서울을 선택했습니다.

Security 설정

  • Database Access
    DB 연결 시 사용할 Username과 Password를 설정합니다. Role을 설정해서 용도와 접근범위에 따라 세분화하여 사용하는 것이 좋습니다. 유저 정보는 되도록 하드코딩하지 않고, 공개되지 않도록 주의합니다.

    참고
    Configure Database Deployment Authentication and Authorization

  • Network Access
    DB 접근이 허용되는 인바운드 주소를 설정합니다. 0.0.0.0/0은 되도록 설정하지 않습니다.

Search Index 설정

Vector Search 기능을 사용하기에 앞서 Search Index를 설정합니다.
데이터베이스 상세 화면에서 Search탭으로 진입 한 후 오른쪽 상단의 CREATE INDEX버튼을 클릭합니다.

JSON Editor를 선택합니다.

Vector Search를 수행할 Collection을 선택한 후 다음과 같이 작성합니다. 상세 내용은 필요에 따라 커스텀합니다. 아래 Search Index는 필드 embeddingtimestamp에 대해 인덱싱을 설정합니다.

{
  "mappings": {
    "dynamic": true,
    "fields": {
      "embedding": {
        "dimensions": 1536,
        "similarity": "cosine",
        "type": "knnVector"
      },
      "timestamp": {
        "type": "number"
      }
    }
  }
}

참고

📂 LangChain

import

모델 로드 및 embedding 정의를 위해 다음을 선언합니다.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain.llms import HuggingFacePipeline
from langchain.embeddings.openai import OpenAIEmbeddings

Chain 및 클래스 커스텀을 위해 다음을 선언합니다.

from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from langchain.vectorstores import MongoDBAtlasVectorSearch
from langchain.memory import VectorStoreRetrieverMemory
from langchain.schema.document import Document

typing을 위한 선언입니다.

from typing import Optional, Dict, Any, List

ConversationChain 구성하기

ConversationChain을 사용해서 메시지를 생성하는 Chain을 구성해보겠습니다. memory로는 앞서 구성한 MongoDB Atlas 클러스터를 사용합니다.

참고

모델 로드

먼저 다음과 같이 모델을 로드합니다.

model = AutoModelForCausalLM.from_pretrained(
	# 모델 로드 정보
).to(device='cuda', non_blocking=True)

tokenizer = AutoTokenizer.from_pretrained(''# model_path)
model.eval()
model.config.use_cache = True

pipe = pipeline(
    'text-generation',
    model = model,
    tokenizer = tokenizer
)

local_llm = HuggingFacePipeline(pipeline=pipe)

참고

Vector Search 구성

Vector Search를 수행할 MongoDB Atlas 클러스터의 Collection을 지정합니다.

from pymongo import MongoClient

embedding_fn = OpenAIEmbeddings(
	openai_api_key= ''#OpenAI API KEY
)

mongodb_client = MongoClient(''# MongoDB Connection String URI)
collection = mongodb_client[''# collection name]

vectorstore = MongoDBAtlasVectorSearch(
	collection, embedding_fn
)

참고

Prompt 선언

프롬프트 템플릿을 선언합니다.

template = ''# 프롬프트 템플릿
prompt = PromptTemplate(
	input_variables=['history', '### 명령어'], template=template
)

참고

VectorStoreRetrieverMemory 커스텀하기

Chain을 실행하는 과정에서 embedding 벡터 저장 시 세션 정보와 타임스탬프 같은 정보도 함께 metadata로 저장할 수 있도록 VectorStoreRetrieverMemory_form_documents를 다음과 같이 오버라이딩하겠습니다.

class CustomVectorStoreRetrieverMemory(VectorStoreRetrieverMemory):

    metadata: Optional[Dict[str, Any]] = None,

    def _form_documents(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> List[Document]:
        """Format context from this conversation to buffer."""
        # Each document should only include the current turn, not the chat history
        filtered_inputs = {k: v for k, v in inputs.items() if k != self.memory_key}
        texts = [
            f"{k}: {v}"
            for k, v in list(filtered_inputs.items()) + list(outputs.items())
        ]
        page_content = "\n".join(texts)
        return [Document(page_content=page_content, metadata=self.metadata)]

ConversationChain 구성

ConversationChain을 선언하기 앞서, 제가 DB에 읽고 쓰고자 하는 데이터의 스키마를 먼저 간략하게 살펴보겠습니다.

FieldTypeDescription
_idObjectID
textString'### 명령어', '### 응답' 쌍으로 구성된 메시지 텍스트
embeddingArray1536 차원
user_idString유저 세션 식별용, 16자리 랜덤 해시
timestampdouble유닉스 타임스탬프, 메시지 생성 시각

_id, text, embedding은 대화 턴이 memory에 추가되는 과정에서 추가되지만, 그 외 필드는 metadata로 전달해서 저장하게 됩니다.

다음과 같이 ConversationChain을 구성합니다.

my_chain = ConversationChain(
	llm=local_llm,
    prompt=prompt,
    memory=CustomVectorStoreRetrieverMemory(
    	retriever=vectorstore.as_retriever(search_kwargs={
        	'k':2,
            'pre_filter': build_pre_filter(user_id, timestamp),
            'post_filter_pipeline': build_post_filter_pipeline()
        }),
        metadata={'user_id': user_id, 'timestamp': timestamp},
   ),
   input_key='### 명령어',
   output_key='### 응답',
   verbose=True
)

build_pre_filter()build_post_filter_pipeline()은 별도로 선언해서 사용했으나 리터럴로 바로 쿼리를 작성해서 넣어도 무방합니다.

위 체인은 실행시에 지정된 컬렉션에서 pre_filter를 적용하여 유사도 점수가 가장 높은 k개의 메시지를 검색한 후 post_filter_pipline으로 전달된 쿼리를 적용하여 retrieval 결과를 반환합니다. 새로 입력된 메시지는 전달된 metadata값을 포함하여 새로 생성된 메시지와 쌍으로 저장합니다.

build_pre_filter() 살펴보기

def build_pre_filter(user_id: str, timestamp: float) -> dict:
    return {
        'compound': {
                    'filter': {
                        'text': {
                            'path': 'user_id',
                            'query': user_id
                            }
                    },
                    'should': {
                        'near': {
                            'origin': timestamp,
                            'path': 'timestamp',
                            'pivot': 10000000
                            }
                    }
        }
    }

세션 id로 메시지를 필터링하고 최신 메시지일수록 가중치를 부여하기 위해 다음과 같이 쿼리를 만들었습니다. 위 쿼리는 user_id와 일치하는 메시지 중에서, near 연산자를 사용해 메시지의 timestamp값이 현재 입력된 메시지의 timestamp값과 가까울수록 유사도 점수에 가중치를 부여합니다.

참고

메시지 생성하기

새 메시지를 question이라고 할 때 다음과 같이 체인을 실행하고 답변 메시지를 생성할 수 있습니다.

input_dict = {'### 명령어': question}
response = my_chain.predict(**input_dict)

📂 마치며

레퍼런스가 정말 없었던 건지, 제대로 찾지 못했던 건지... 프로젝트를 진행하면서 정말 많이 헤매고 고민했습니다. 일단 프로젝트 기간 내에 원하는 대로 동작하는 데모를 구현하는 데에는 성공했지만, 시간이 좀 더 있었다면 더 나은 방법을 찾을 수 있었을지도 모르겠습니다. 🥲

저의 삽질이 여러분에게 작은 힌트나 아이디어가 될 수 있기를 바라며 마칩니다!

관련 링크

0개의 댓글