최근에 DPR, RAG, RETRO, FiD 등을 보면서 Retrieval에 대한 내용들이 많이 나온다. faiss에 대해 살펴보고자 한다.
faiss는 facebook research에서 개발한, dense vector들의 클러스터링과 유사도를 구할때 사용하는 라이브러리이다. C++로 작성되었으며 python에서 지원된다. 그리고 GPU 상에서도 효율적으로 동작하도록 개발 되었다.
특정한 차원 을 가진 벡터들의 집합 이 있을때, faiss는 데이터 구조들을 램 위에 올려두고 새로운 벡터가 들어 왔을때 거리가 가장 적은 벡터를 계산하는 것을 말한다.
은 유클라디안 거리 를 나타낸다.
faiss에서 data structure는 index라 한하며 객체들을 와 더하기 위한 더하기 메서드를 가진다.
argmin을 구하는 것은 index에 대해서
search
연산을 하는것을 말한다.
공식 문서에서는 conda를 통한 설치를 추천하므로 참고.
pip3 install pytorch faiss-cpu
faiss-gpu
provide cuda-enabled indices
pip3 install pytorch faiss-gpu
faiss에서는 index
라는 개념을 사용한다. index는 데이터베이스 벡터들의 집합을 캡슐화 하고 효율적으로 검색하기 위해 선택적으로 전처리를 할 수도 있다.
faiss에는 여러가지 인덱스 타입들이 있으며, 가장 단순하게 사용할수 있는 알고리즘은 brute-force L2 distance 검색을 하는 IndexFlatL2
이다
모든 인덱스들은 빌드 될때 어떤 벡터 차원 d
에서 연산되는지 정보를 필요로 한다. 그리고 대부분의 인덱스들은 트레이닝 단계를 필요로 한다. 인덱스들의 트레이닝이 필요한 이유는 인덱스를 구성하는 벡터들의 분포를 분석할 필요 때문이다. 하지만 IndexFlatL2
의 경우 학습을 필요로 하지 않는다.
인덱스를 빌드하고 학습할때 add
와 search
두가지 연산이 인덱스에 대해 수행된다.
인덱스에 어떤 벡터 x1를 더하기 위해 add
연산을 사용하고, 학습이 되었는지 나타내는 상태인 is_trained
와 인덱싱된 벡터들의 갯수를 나타내는 ntotal
을 확인 할수 있다.
import faiss
import torch
database_size = 100000
num_queries = 1000
db_vector = torch.randn((database_size,10))
query_vector = torch.randn((num_queries,10))
# Build an index와 index에 벡터 더하기
dimension = 10
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(db_vector.numpy())
print(faiss_index.ntotal)
print(faiss_index.is_trained)
기본적인 검색 연산은 인덱스에서 k-nearest-neighbor
을 통해 할수 있다.
각 쿼리 벡터에 대해서 k개의 근접 이웃 벡터를 검색해준다.
이 때 integer 매트릭스 안에 저장되고, 매트릭스의 사이즈는 (nq, k)이다 . row i 는 쿼리벡터 i에 대한 neighbors의 id들을 담고 있다. neighbor 들은 거리가 증가하는 순서로 정렬 되어 잇으며, search 연산은 (nq, k)의 모양의 squared distance, floating-point 매트릭스를 반환한다.
k = 4 # k-th nearest
Distance, Index = faiss_index.search(db_vector[:5].numpy(), k) # sanity check
print(f'Index: {Index}')
print(f'Distance: {Distance}')
Distance, Index = faiss_index.search(query_vector.numpy(), k) # actual serach
print(f'Index: {Index[:5]}')
print(f'Distance: {Distance[:5]}')
Index:
[[ 0 68887 49722 23151]
[ 1 72911 82286 52959]
[ 2 69650 11083 47297]
[ 3 9952 37800 7229]
[ 4 1455 32235 22630]]
Distance:
[[0. 2.4389632 2.4436438 2.7420993 ]
[0. 0.73145294 1.6739532 1.8013455 ]
[0. 1.1959426 1.6530507 1.7116383 ]
[0. 1.6106968 1.7268802 1.86626 ]
[0. 1.9599087 2.4512682 2.5513403 ]]
각 쿼리의 근접 이웃은 벡터의 인덱스이고, 대응되는 거리는 0이다. 그리고 각 row에서 거리의 크기는 증가한다.
Index:
[[73587 2746 15265 96434]
[98388 13550 93912 92610]
[97530 93498 16607 98168]
[52308 24908 70869 20824]
[44597 35140 7572 4596]]
Distance:
[[1.3511505 1.7502174 1.9366131 2.090292 ]
[1.8951168 2.255947 2.2716331 2.2864761 ]
[0.89173985 1.0351162 1.4841261 1.5248547 ]
[1.0868998 1.3354883 1.3459816 1.419671 ]
[1.2155628 1.360589 1.4838543 1.7606869 ]]
actual search에서 나온 결과는 각 쿼리와 유사한 근접 이웃의 인덱스와 그 인덱스와의 거리르 가까운 순으로 보여준다.