[langchain] VectorStoreRetriever

Jeong-Minju · August 2, 2024

1. VectorStore

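The vectorstore is a component of RAG (Retrieval-Augmented Generation), which has been a hot topic lately. A vectorstore (vector DB) is the space that holds the vectors whose similarity is compared against the embedding vector of the input text. langchain implements integrations for many vector DBs. A few of them are listed below.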

  • faiss
  • pinecone
  • elasticsearch
  • chromaDB
  • ...

If you want to see the full range of vectorstores, please refer to the GitHub link. The examples here use the faiss vectorstore, one of the most widely used ones.
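Before moving on to FAISS, it may help to see the idea of a vectorstore in plain Python. The following is only a minimal sketch, not how langchain or faiss work internally: `ToyVectorStore`, its methods, and the hand-written two-dimensional vectors are all hypothetical names and values made up for illustration. It stores (text, vector) pairs and returns the texts whose vectors have the highest cosine similarity to a query vector.

```python
import math
from typing import List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


class ToyVectorStore:
    """Hypothetical in-memory store that keeps (text, vector) pairs in a list."""

    def __init__(self) -> None:
        self._items: List[Tuple[str, List[float]]] = []

    def add(self, text: str, vector: List[float]) -> None:
        self._items.append((text, vector))

    def similarity_search(self, query_vector: List[float], k: int = 4) -> List[str]:
        """Return the k stored texts whose vectors are most similar to the query."""
        scored = [
            (cosine_similarity(query_vector, vector), text)
            for text, vector in self._items
        ]
        # Highest similarity first.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [text for _, text in scored[:k]]


store = ToyVectorStore()
store.add("cats", [1.0, 0.0])
store.add("dogs", [0.9, 0.1])
store.add("stocks", [0.0, 1.0])

# Made-up query vector pointing in the same direction as "cats".
top_two = store.similarity_search([1.0, 0.0], k=2)
print(top_two)
```

A real vectorstore gets its vectors from an embedding model and uses an index instead of a linear scan, but the contract is the same: vectors go in, and the nearest stored items come out.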

1.1 FAISS vectorstore


from langchain_community.vectorstores.faiss import FAISS

# The constructor also expects the faiss index itself as the second argument
faiss_vectorstore = FAISS(embedding_function, index, docstore, index_to_docstore_id, ...)

Besides calling the constructor as above, you can also create a FAISS vectorstore with the from_...(...) functions.

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter

# Load the example text file
loader = TextLoader("../../how_to/state_of_the_union.txt")
# Convert it into langchain_core.documents.Document objects and split them
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Declare the embedding model that produces the vectors stored in the vectorstore
embeddings = OpenAIEmbeddings()
# Create the vectorstore
db = FAISS.from_documents(docs, embeddings)

The example above creates a vectorstore with the from_documents(...) function.

2. VectorRetriever

Each vectorstore is built by inheriting from the langchain_core.vectorstores.base.VectorStore class, as shown below.

class FAISS(VectorStore):
    """`Meta Faiss` vector store.

    To use, you must have the ``faiss`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.embeddings.openai import OpenAIEmbeddings
            from langchain_community.vectorstores import FAISS

            embeddings = OpenAIEmbeddings()
            texts = ["FAISS is an important library", "LangChain supports FAISS"]
            faiss = FAISS.from_texts(texts, embeddings)

    """

Put simply, the VectorStore class provides the functionality for loading, adding, and searching documents.

class VectorStore(ABC):
...
  def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
          """Return VectorStoreRetriever initialized from this VectorStore.

          Args:
              **kwargs: Keyword arguments to pass to the search function.
                  Can include:
                  search_type (Optional[str]): Defines the type of search that
                      the Retriever should perform.
                      Can be "similarity" (default), "mmr", or
                      "similarity_score_threshold".
                  search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                      search function. Can include things like:
                          k: Amount of documents to return (Default: 4)
                          score_threshold: Minimum relevance threshold
                              for similarity_score_threshold
                          fetch_k: Amount of documents to pass to MMR algorithm
                              (Default: 20)
                          lambda_mult: Diversity of results returned by MMR;
                              1 for minimum diversity and 0 for maximum. (Default: 0.5)
                          filter: Filter by document metadata

          Returns:
              VectorStoreRetriever: Retriever class for VectorStore.

          Examples:

          .. code-block:: python

              # Retrieve more documents with higher diversity
              # Useful if your dataset has many similar documents
              docsearch.as_retriever(
                  search_type="mmr",
                  search_kwargs={'k': 6, 'lambda_mult': 0.25}
              )

              # Fetch more documents for the MMR algorithm to consider
              # But only return the top 5
              docsearch.as_retriever(
                  search_type="mmr",
                  search_kwargs={'k': 5, 'fetch_k': 50}
              )

              # Only retrieve documents that have a relevance score
              # Above a certain threshold
              docsearch.as_retriever(
                  search_type="similarity_score_threshold",
                  search_kwargs={'score_threshold': 0.8}
              )

              # Only get the single most similar document from the dataset
              docsearch.as_retriever(search_kwargs={'k': 1})

              # Use a filter to only retrieve documents from a specific paper
              docsearch.as_retriever(
                  search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
              )
          """
          tags = kwargs.pop("tags", None) or [] + self._get_retriever_tags()
          return VectorStoreRetriever(vectorstore=self, tags=tags, **kwargs)

One of its methods is as_retriever(...). This function returns a VectorStoreRetriever, which we then use to run vector search against the vectorstore we declared.
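The pattern behind as_retriever(...) is small: the store hands itself and the search settings to a retriever object, and the retriever later forwards those settings to the store's search method. Here is a rough sketch of that pattern with no langchain dependency. `ToyStore` and `ToyRetriever` are hypothetical classes, and the word-overlap ranking is just a placeholder for real vector similarity.

```python
from typing import Any, Dict, List, Optional


class ToyRetriever:
    """Hypothetical retriever: remembers the store and the search settings."""

    def __init__(
        self,
        store: "ToyStore",
        search_type: str = "similarity",
        search_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.store = store
        self.search_type = search_type
        self.search_kwargs = search_kwargs or {}

    def invoke(self, query: str) -> List[str]:
        # Forward the stored keyword arguments to the store's search method.
        return self.store.similarity_search(query, **self.search_kwargs)


class ToyStore:
    """Hypothetical store that ranks texts by words shared with the query."""

    def __init__(self, texts: List[str]) -> None:
        self.texts = texts

    def similarity_search(self, query: str, k: int = 4) -> List[str]:
        query_words = set(query.lower().split())
        ranked = sorted(
            self.texts,
            key=lambda text: len(query_words & set(text.lower().split())),
            reverse=True,
        )
        return ranked[:k]

    def as_retriever(self, **kwargs: Any) -> ToyRetriever:
        # Same idea as the method above: pass self plus the settings along.
        return ToyRetriever(store=self, **kwargs)


toy_store = ToyStore(
    ["faiss is a vector index", "langchain wraps faiss", "bananas are yellow"]
)
toy_retriever = toy_store.as_retriever(search_kwargs={"k": 1})
result = toy_retriever.invoke("what is faiss")
print(result)
```

The point of the design is that the caller configures the search once, at as_retriever(...) time, and afterwards only has to pass a query.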

2.1 VectorStoreRetriever & BaseRetriever

In the previous part, we saw that VectorStoreRetriever lets us use a vectorstore as a retriever. Now let's look at VectorStoreRetriever in more detail.

class VectorStoreRetriever(BaseRetriever):
    """Base Retriever class for VectorStore."""

    vectorstore: VectorStore
    """VectorStore to use for retrieval."""
    search_type: str = "similarity"
    """Type of search to perform. Defaults to "similarity"."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
    )

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate search type.

        Args:
            values: Values to validate.

        Returns:
            Values: Validated values.

        Raises:
            ValueError: If search_type is not one of the allowed search types.
            ValueError: If score_threshold is not specified with a float value(0~1)
        """
        search_type = values.get("search_type", "similarity")
        if search_type not in cls.allowed_search_types:
            raise ValueError(
                f"search_type of {search_type} not allowed. Valid values are: "
                f"{cls.allowed_search_types}"
            )
        if search_type == "similarity_score_threshold":
            score_threshold = values.get("search_kwargs", {}).get("score_threshold")
            if (score_threshold is None) or (not isinstance(score_threshold, float)):
                raise ValueError(
                    "`score_threshold` is not specified with a float value(0~1) "
                    "in `search_kwargs`."
                )
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Add documents to the vectorstore.

        Args:
            documents: Documents to add to the vectorstore.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            List of IDs of the added texts.
        """
        return self.vectorstore.add_documents(documents, **kwargs)

    async def aadd_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[str]:
        """Async add documents to the vectorstore.

        Args:
            documents: Documents to add to the vectorstore.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            List of IDs of the added texts.
        """
        return await self.vectorstore.aadd_documents(documents, **kwargs)

VectorStoreRetriever inherits from langchain_core.retrievers.BaseRetriever and can be imported with "from langchain_core.vectorstores.base import VectorStoreRetriever".

Its main features are the following.

  • Adding documents to the vectorstore
  • Searching the vectorstore for similar documents
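Besides these two features, the class quoted above also validates its settings in validate_search_type. One detail worth noticing in that code is the isinstance(score_threshold, float) check: in the version quoted here, an integer such as 1 would be rejected even though it looks like a valid threshold. The sketch below restates the same checks as a standalone function so they can be tried without pydantic. `validate_search_config` and `is_accepted` are hypothetical helper names, not part of langchain.

```python
from typing import Any, Dict

ALLOWED_SEARCH_TYPES = ("similarity", "similarity_score_threshold", "mmr")


def validate_search_config(values: Dict[str, Any]) -> Dict[str, Any]:
    """Standalone restatement of the checks in the quoted validate_search_type."""
    search_type = values.get("search_type", "similarity")
    if search_type not in ALLOWED_SEARCH_TYPES:
        raise ValueError(f"search_type of {search_type} not allowed.")
    if search_type == "similarity_score_threshold":
        score_threshold = values.get("search_kwargs", {}).get("score_threshold")
        # Mirrors the quoted source: the value must be present and a float.
        if score_threshold is None or not isinstance(score_threshold, float):
            raise ValueError("`score_threshold` must be a float in `search_kwargs`.")
    return values


def is_accepted(values: Dict[str, Any]) -> bool:
    """Return True when the configuration passes validation."""
    try:
        validate_search_config(values)
        return True
    except ValueError:
        return False


results = {
    "default": is_accepted({}),
    "float_threshold": is_accepted(
        {
            "search_type": "similarity_score_threshold",
            "search_kwargs": {"score_threshold": 0.8},
        }
    ),
    "int_threshold": is_accepted(
        {
            "search_type": "similarity_score_threshold",
            "search_kwargs": {"score_threshold": 1},
        }
    ),
    "unknown_type": is_accepted({"search_type": "bm25"}),
}
print(results)
```

So when using "similarity_score_threshold", writing the threshold as 1.0 rather than 1 is the safer habit.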

Let's take a closer look at document search, the core function of a retriever.

def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

A different vector search runs depending on self.search_type. Because each search method is implemented differently for each vectorstore, the actual work lives inside the wrapper classes that langchain provides, such as langchain_community.vectorstores.faiss.FAISS.
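Of the three search types, "mmr" (maximal marginal relevance) is the least obvious, so here is a small sketch of the idea in plain Python. This is not langchain's or faiss's implementation, only an illustration under simple assumptions: `mmr_select` is a hypothetical function, the vectors are made up, and cosine similarity is used for both relevance and redundancy. Each step picks the candidate that maximizes lambda_mult * (similarity to the query) minus (1 - lambda_mult) * (highest similarity to anything already picked).

```python
import math
from typing import List


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mmr_select(
    query: List[float],
    candidates: List[List[float]],
    k: int = 4,
    lambda_mult: float = 0.5,
) -> List[int]:
    """Return indices of up to k candidates chosen by maximal marginal relevance."""
    selected: List[int] = []
    remaining = list(range(len(candidates)))
    while remaining and len(selected) < k:
        best_index = remaining[0]
        best_score = -math.inf
        for i in remaining:
            relevance = cosine(query, candidates[i])
            # Penalty: how close this candidate is to something already chosen.
            redundancy = max(
                (cosine(candidates[i], candidates[j]) for j in selected),
                default=0.0,
            )
            score = lambda_mult * relevance - (1.0 - lambda_mult) * redundancy
            if score > best_score:
                best_index, best_score = i, score
        selected.append(best_index)
        remaining.remove(best_index)
    return selected


query_vector = [1.0, 0.0]
candidate_vectors = [
    [0.9, 0.1],   # index 0: very relevant
    [0.9, 0.12],  # index 1: near-duplicate of index 0
    [0.5, -0.5],  # index 2: less relevant but different
]

relevance_only = mmr_select(query_vector, candidate_vectors, k=2, lambda_mult=1.0)
balanced = mmr_select(query_vector, candidate_vectors, k=2, lambda_mult=0.5)
print(relevance_only, balanced)
```

With lambda_mult=1.0 the near-duplicate is returned as the second result, while with lambda_mult=0.5 it is skipped in favor of the more distinct vector. This matches the as_retriever docstring above: 1 means minimum diversity and 0 means maximum.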

# Method inside langchain_community.vectorstores.faiss.FAISS
    ...
    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Union[Callable, Dict[str, Any]]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                      Defaults to 20.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k, filter=filter, fetch_k=fetch_k, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]
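The last line above throws the scores away and keeps only the documents. With the "similarity_score_threshold" search type, the scores are first compared against score_threshold, and only then dropped. Below is a minimal sketch of that filtering step, assuming relevance scores where a higher value means more similar. `filter_by_threshold` is a hypothetical helper, and the documents are plain strings with made-up scores.

```python
from typing import List, Tuple


def filter_by_threshold(
    docs_and_scores: List[Tuple[str, float]], score_threshold: float
) -> List[str]:
    """Keep documents scoring at or above the threshold, then drop the scores."""
    kept = [
        (doc, score) for doc, score in docs_and_scores if score >= score_threshold
    ]
    return [doc for doc, _ in kept]


# Made-up (document, relevance score) pairs for illustration.
scored_docs = [("doc A", 0.91), ("doc B", 0.80), ("doc C", 0.42)]
passing = filter_by_threshold(scored_docs, score_threshold=0.8)
print(passing)
```

Unlike a fixed k, a threshold can return fewer documents than requested, or none at all, when nothing is similar enough.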

We'll wrap up this post with an example of searching for similar documents with a FAISS retriever.

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter

# Load the example text file
loader = TextLoader("../../how_to/state_of_the_union.txt")
# Convert it into langchain_core.documents.Document objects and split them
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Declare the embedding model that produces the vectors stored in the vectorstore
embeddings = OpenAIEmbeddings()
# Create the vectorstore
db = FAISS.from_documents(docs, embeddings)

# Create the retriever
faiss_retriever = db.as_retriever()

# Search for similar documents
faiss_retriever.invoke(input="What is state of union?")

When searching with a retriever, you can call invoke just as you would with any other langchain Runnable object.
