๐ ์๋
ํ์ธ์! ์ค๋์ ์ ์ฌ๋ ๋ฐ KNN ๊ณ์ฐ์ ๋น ๋ฅด๊ณ ๊ฐํธํ ์ํํ ์ ์๋ ํจํค์ง!
๐ "Faiss"์ ๋ํด ๊ฐ๋จํ ์ ๊ฐ ์ฌ์ฉํ ๋ถ๋ถ์ ์ ๋ฆฌํ๋ ค ํฉ๋๋ค.
๐ ๊ณต๋ถํ๋ ๊ฐ๋
๋ค์ ํ์ฉํด์, ์๋ Kaggle ๋ํ์ ์ ์ถํ ์ฝ๋๋ฅผ ์์ฑ์ค์ด์์ต๋๋ค!
๐ ์๋ ๋ํ๋ ๋๊ณ ๋์ Fin(๋ฑ์ง๋๋ฌ๋ฏธ)๋ฅผ ํตํด ๊ฐ์ฒด๋ฅผ ๋ถ๋ณํ๋ Task์
๋๋ค.
๐ง ๋ฐ์ดํฐ์
์์ฒด๊ฐ ๋๊ณ ๋ ๊ฐ์ฒด๋ง๋ค ๋ช์ฅ ๋์ง ์๊ณ , class์๋ ์๋นํ ๋ง๊ธฐ ๋๋ฌธ์ ์ผ์ข
์ "Face Recognition Task" ๋ก ์ ๊ทผํ์ฌ ํด๊ฒฐํ๋ ์๋ฃจ์
๋ค์ด ๋ง์์ต๋๋ค.
โจ ์ด์, ์ ๋ Face Recognition์ ํจ๊ณผ๊ฐ ์ข์ ArcFace๋ฅผ ์จ๋ณด์! ๋ผ๋ ํ๋ฆ์ด ๋์์ฃ :)
๐ ๊ทธ๋ ์ต๋๋ค. ์ด์ ์ ์ ๋ฆฌํ๋ ์๋ ๋ ํฌ์คํ
์
๋๋ค!
๐ ViT๋ ๋ชจ๋ธ ์ํคํ
์ฒ๋ก, ArcFace๋ loss function์ผ๋ก ํ์ฉ๋์์ต๋๋ค.
๐ถ ViT
https://velog.io/@gtpgg1013/pytorch-Image-Classification-Using-ViT
๐ ArcFace
https://velog.io/@gtpgg1013/%EB%85%BC%EB%AC%B8%EB%A6%AC%EB%B7%B0-ArcFace-Additive-Angular-Margin-Loss-for-Deep-Face-Recognition
๊ทธ๋ฌ๋!
ํํซ.. ๊ฑฐ์ฐธ ๐คฃ
์ ๊ฐ ์์ ์ ํ์์ ์์ฃผ ์ฌ์ฉํ๋ Softmax layer์์ ์ต๋๊ฐ์ ๋ฝ์ ๊ฒฐ๊ด๊ฐ์ ๋ง๋ค๋ ๋ฐฉ์์ด ์๋๊ธฐ ๋๋ฌธ์,
๊ฒฐ๊ณผ๋ฅผ ์ ์ถํ๊ธฐ ์ํด ์ถ๊ฐ์ ์ธ ํ์ฒ๋ฆฌ๊ฐ ํ์ํ ์ํฉ์ด ๋ฐ์ํ์์ต๋๋ค.
(์ฐธ๊ณ )
๐ ๋ณดํต ์ผ๋ฐ์ ์ธ Classification Task๋ softmax layer์์ ๊ฐ์ ๋ฝ์์, ์ด๋ฅผ argmax๋ฅผ ์ทจํ ๊ฐ์ ๊ฒฐ๊ณผ index๋ก ์ฝ๊ฒ ์ฌ์ฉํ ์ ์์ต๋๋ค.
๐ค ํ์ง๋ง, ArcFace loss function์ ํ์ฉํ ์ถ๋ก ๊ฒฐ๊ณผ๋ ๊ฐ image๋ค ํน์ dim์ผ๋ก Embeddingํ๊ธฐ ๋๋ฌธ์, ํ ์คํธ ๋ฐ ํ์ต ๋ฐ์ดํฐ๋ค์ ๋ชจ๋ ๋ชจ๋ธ์ ํ์ฉํ์ Embedding์ ๋ง๋ค๊ณ , ๊ฐ Embedding์ ์ ์ฌ๋๋ฅผ ํ์ฉํ์ฌ ๊ฐ์ฒด ์ธ์์ ์ํํฉ๋๋ค.
๐ ๋ฌผ๋ก ์ ์ฌ๋๋ฅผ ๊ตฌํ๋ for loop๋ฅผ ๊ทธ๋ฅ ์ง๋ ๋ฉ๋๋ค๋ง...
๐ ๋ถ๋ช
๋๊ฐ ์ข์ ๊ฒ์ ๋ฏธ๋ฆฌ ๋ง๋ค์ด ๋์์ ๊ฒ์ด๋ผ๊ณ ์๊ฐํ๊ณ , ์ฐธ๊ณ ํ ์๋ฃจ์
์์ ์ฐพ๊ฒ ๋์์ต๋๋ค!
github : https://github.com/facebookresearch/faiss
wiki : https://github.com/facebookresearch/faiss/wiki
โจ Faiss๋ Facebook์์ ๋ง๋ ํจ์จ์ ์ธ ์ ์ฌ๋ ๊ฒ์ ๋ฐ clustering์ ์ํด์ ๋ง๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก์, C++๋ก ์ง์ฌ์์ผ๋ฉฐ, GPU ํ์ฉ๊น์ง ๊ฐ๋ฅํ๊ธฐ ๋๋ฌธ์ ํ์ฉ๋ ๋ฐ ํจ์จ์ฑ์ด ๋๋ค๊ณ ํ๋ค์ :)
๐ ์ ๋ L2 distance์ ํ์ฉํ์ฌ Embedding๊ฐ distance๋ฅผ ๊ตฌํ์์ต๋๋ค :)
pip install faiss-gpu
๐ Faiss๋ index ์์ฑ => index์ db ๋ฑ๋ก => query๋ก db search ์์๋ก ์งํ๋ฉ๋๋ค.
๐ index.search ํจ์์ ์ธ์ k ๊ฐฏ์๋งํผ ์ ์ฌํ ์์๋๋ก ์ฐพ์ต๋๋ค.
๐ ๊ทธ๋ฆฌ๊ณ I๋ db์ index / D๋ query์ ํด๋น index์ db์์ distance ์
๋๋ค.
import numpy as np
import faiss
# ํ์ ๋์ : db
db = np.array(np.random.random((100,32)), np.float32)
# db์ ์ง์ : query
query = np.array(np.random.random((2,32)), np.float32)
# ์ ์ฌ๋ ๊ณ์ฐ
def create_and_search_index(embedding_size, db_embeddings, query_embeddings, k):
# ํน์ embedding size(32)์ faiss index ์์ฑ
index = faiss.IndexFlatL2(embedding_size)
# db ๋ฑ๋ก
index.add(db_embeddings)
# k๊ฐ์ ์ ์ฌํ ๊ฐ search
# I๋ db์ index / D๋ query์ ํด๋น index์ db์์ distance
D, I = index.search(query_embeddings, k=k)
return D, I
D, I = create_and_search_index(32, db, query, 5)
D
>>> array([[2.2877078, 2.4738002, 2.5330005, 2.8156984, 2.8736565],
[2.5150692, 2.5936663, 3.0449693, 3.14107 , 3.198256 ]],
dtype=float32)
# ์ฆ, ์ฒซ๋ฒ์งธ query์ ๊ฐ์ฅ ๊ฐ๊น์ด db์ index๋ 35, 2๋ฒ์งธ query์ ๊ฐ์ฅ ๊ฐ๊น์ด db์ index๋ 10์ด๋ค.
I
>>> array([[35, 10, 83, 85, 55],
[10, 84, 89, 44, 92]])
๐ ์์ ๊ฐ์ด ํ์ฉํ๋ฉด ๊ฐ๋จํ ์ ์ฌ๋ ๊ณ์ฐ์ ํ ์ ์์ต๋๋ค :)
๐ L2 Distance ์ด์ธ์๋ ๋ค์ํ distance๋ฅผ ํ์ฉํ์ฌ ์ ์ฌ๋๋ฅผ ๊ตฌํ ์ ์์ผ๋, ์๋ ํ์ด์ง๋ฅผ ์ฐธ์กฐํ์๋ฉด ๋ฉ๋๋ค.
https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
๐ ๊ฐ๋ ฅํ ๊ธฐ๋ฅ์ผ๋ก ์ ์ฌ๋๋ฅผ ํ๋ฐฉ์ ์ ๋ฆฌํด์ฃผ๋๋ก ๋์์ค ํจํค์ง Faiss์ ๋ํด ์ ๋ฆฌํด๋ณด์์ต๋๋ค.
๐ ํ์คํ ์ถ์ฒ ์์คํ
์ด๋ ์ ์ฌ๋ ๋น๊ต๊ฐ ๋ง์ ํ๋ก์ธ์ค์์ ์ ์ฉํ๊ฒ ์ฐ์ผ ๊ฒ ๊ฐ๋ค๋ ์๊ฐ์ด ๋๋ค์.
๐ฑโ๐ค ๋ค์์๋ GPU ๊ธฐ๋ฅ๋ ํ์ฉํด๋ณด๊ณ , ํ์ํ ์ํฉ์ ๋ง์ถฐ distance๋ ํ์ฉํ ์ ์๋ ๊ธฐํ๊ฐ ๊ณง ์ค์ง์์๊น ์๊ฐ์ด ๋ญ๋๋ค.
๐ ๊ทธ๋ผ, ์ฌ๊ธฐ๊น์ง ์ฝ์ด์ฃผ์
์ ๊ฐ์ฌํฉ๋๋ค. ์ข์ ํ๋ฃจ ๋์ธ์!