
Flamingo: a Visual Language Model for Few-Shot Learning
Advances in neural information processing systems 2022
Flamingo๐ฆฉ๋ผ๋ ๋น์ฃผ์ผ ์ธ์ด ๋ชจ๋ธ์ ์๊ฐ
์ด ๋ชจ๋ธ์ ์ฃผ์ด์ง ์์์ ์์๋ก ์๋ก์ด ์์ ์ ๋น ๋ฅด๊ฒ ํ์ตํ ์ ์๋ ๋ฅ๋ ฅ์ ๊ฐ์ถ๊ณ ์์ผ๋ฉฐ, ํ ์คํธ์ ์ด๋ฏธ์ง(๋น๋์ค)๋ฅผ ๊ฒฐํฉํ์ฌ ์๋ก์ด ํ ์คํธ๋ฅผ ์์ฑํ ์ ์๋๋ก ํ์ตํจ. ์ด ๋ชจ๋ธ์ ๋ค์ํ ๋ฒค์น๋งํฌ์์ ํ์ํ ์ฑ๊ณผ๋ฅผ ๋ณด์์ผ๋ฉฐ, ์๋ก์ด ๋ฐ์ดํฐ ์์ด๋ ๊ฐ๋จํ ์์๋ค๋ก ๋์ ์ฑ๋ฅ์ ๋ฌ์ฑํ ์ ์๋ค๋ ์ ์ด ์ฃผ์ ๊ธฐ์ฌ์
ํ์ฌ ์ปดํจํฐ ๋น์ ์์ ์๋ก์ด ์์ ์ ๋น ๋ฅด๊ฒ ๋ฐฐ์ฐ๋ ๊ฒ์ ๋ํ ์ฐ๊ตฌ๊ฐ ์ ์ํ๋๊ณ ์์ง๋ง, ๋๋ถ๋ถ์ ๋ฐฉ๋ฒ๋ค์ ์ฌ์ ํ ๋๊ท๋ชจ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํด ์ฌ์ ํ์ตํ ํ, ๊ด์ฌ ์๋ ์์ ์ ๋ง๊ฒ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ฐฉ์์ ์์กดํ๊ณ ์์. ๊ทธ๋ฌ๋ ์ฑ๊ณต์ ์ธ ๋ฏธ์ธ ์กฐ์ ์ ์ํด์๋ ์์ฒ ๊ฐ์ ์ฃผ์์ด ๋ฌ๋ฆฐ ๋ฐ์ดํฐ ํฌ์ธํธ๊ฐ ํ์ํ๋ฉฐ, ์์ ๋ณ ํ์ดํผํ๋ผ๋ฏธํฐ ์กฐ์ ์ด ํ์ํ ๋ฟ๋ง ์๋๋ผ ์์์ด ๋ง์ด ์๋ชจ๋จ.
์ต๊ทผ์๋ ๋์กฐ์ ๋ชฉ์ (contrastive objective) ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ฌ ํ์ต๋ ๋ค์ค ๋ชจ๋ฌ ๋น์ -์ธ์ด ๋ชจ๋ธ์ด ๋ฏธ์ธ ์กฐ์ ์์ด๋ ์๋ก์ด ์์ ์ ์ ๋ก์ท ์ ์์ ๊ฐ๋ฅํ๊ฒ ํจ. ํ์ง๋ง ์ด๋ฌํ ๋ชจ๋ธ์ ํ ์คํธ์ ์ด๋ฏธ์ง ๊ฐ ์ ์ฌ์ฑ ์ ์๋ง์ ์ ๊ณตํ ์ ์์ด ๋ถ๋ฅ์ ๊ฐ์ ์ ํ๋ ์ฌ์ฉ ์ฌ๋ก์๋ง ์ ์ฉ๋ ์ ์๋ ํ๊ณ๊ฐ ์์.
์ด๋ค์ ์ธ์ด ์์ฑ ๋ฅ๋ ฅ์ด ๋ถ์กฑํ์ฌ ์บก์ ๋์ด๋ ๋น์ฃผ์ผ ์ง๋ฌธ ์๋ต๊ณผ ๊ฐ์ ์์ ์๋ ์ ํฉํ์ง ์์. ์ด๋ฅผ ๊ทน๋ณตํ๊ธฐ ์ํด ์๊ฐ์ ์ธ ๋ด์ฉ์ ๋ํ ์ธ์ด ์์ฑ์ ๋ํ ํ๊ตฌ๋ ์ด๋ฃจ์ด์ก์ง๋ง, ์์ง ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ์ง ๋ชปํ๊ณ ์์.
๋ณธ ๋ ผ๋ฌธ์ ์ด๋ฌํ ํ๊ณ๋ฅผ ๊ทน๋ณตํ๊ณ , ๋ค์ํ ๋น์ ๋ฐ ์ธ์ด ์์ ์์ ์์์ ์์๋ง์ผ๋ก ํ์ตํ ์ ์๋ ๋น์ฃผ์ผ ์ธ์ด ๋ชจ๋ธ์ธ Flamingo๋ฅผ ์๊ฐํจ. ์ด ๋ชจ๋ธ์ ๋ช ๊ฐ์ง ์ ๋ ฅ/์ถ๋ ฅ ์์๋ง์ผ๋ก ๋ค์ํ ์์ ์ ์ ์ฉ๋ ์ ์์ผ๋ฉฐ, ๋ฏธ์ธ ์กฐ์ ์์ด๋ ์ฌ๋ฌ ์์ ์์ ์๋ก์ด ์ฑ๋ฅ์ ๋ณด์ฌ์ค.
(๋ฌด์จ ๋๋ฌผ์ ๋ํ ๋๋ต)

(๋๋ฌผ ์์ ๋ํ ๋๋ต)

(์์ ๋ด์ฉ ์ค๋ช
)


Flamingo๋ ์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด Vision encoder, Language model(LM) block๋ก ๋๋๋ฉฐ, Perceiver Resampler(์ฐ๋ณด๋ผ์), Gated XATTN-Dense๊ฐ ์๋ ๊ฒ์ด ํน์ง
Vision encoder : CLIP text-image contrastive learning์ผ๋ก ํ์ต๋ ๋น์ ์ธ์ฝ๋๋ฅผ ๋ถ๋ฌ์ ์ฌ์ฉํจ
Language model : Large text corpus๋ก ํ์ต๋ Chinchilla๋ฅผ ์ฌ์ ํ์ต ๋ถ๋ฌ์ ์ฌ์ฉํจ
Perceiver Resampler : ๋น์ ์ธ์ฝ๋์ ์ถ๋ ฅ์ ๊ณ ์ ๋ ํฌ๊ธฐ๋ก mappingํ๋ ์ญํ (๊ณ ์ ํฌ๊ธฐ๋งํผ ์์์ ธ์ ๊ณ์ฐ ํจ์จ์ )
Gated XATTN-Dense : Query, Key, Value๋ฅผ ์ ๋ ฅ ๋ฐ์ ํด๋น ์ ๋ณด๊ฐ ๊ฐ๋ฏธ๋ ๋ฒกํฐ๋ฅผ ์ถ๋ ฅํ๋ ๋ ์ด์ด
๋ค์์ผ๋ก Perceiver Resampler์ Gated XATTN-Dense์ ๋ํด์ ์ค๋ช

์ฐ์ Perceiver Resampler๋ ๋ณธ ๋ ผ๋ฌธ(2022)์์ ์๋กญ๊ฒ ์ ์ํ ๊ฑด ์๋๊ณ 2021๋ Perceiver: General Perception with Iterative Attention์์ ์ ์๋จ. ๋ณธ ๋ ผ๋ฌธ์ Cross Attention์ ์ ๋ ฅํ๊ธฐ ์ํ ๋ฆฌ์ํ๋ง ๋ฐฉ๋ฒ์ ์ฌ์ฉํ ๊ฒ์ด๋ผ ๋ณด๋ฉด ๋จ.
์ฐ์ ์ค๋ช ์ ์, Flamingo ๋ชจ๋ธ์์๋ ์ฐ์ Vision Feature์ Text Feature๋ฅผ ๊ฐ์ด ์ฐ์ฐํด์ฃผ์ด ํ๋ค๋ ๊ฑธ ์๊ฐํด๋ณด์.
์ด๋ Vision Feature๋ Text Feature๋ณด๋ค ์ผ๋ฐ์ ์ผ๋ก ํจ์ฌ ํฐ ์ฐจ์์ ๊ฐ๊ธฐ ๋๋ฌธ์ ์ด ๋์ ๋์์ ์ฐ์ฐํ๊ธฐ ์ํด์๋ Vision Feature์ ์ฐจ์์ Text Feature์ ๋ง๊ฒ ์ถ์ํด์ค ํ์๊ฐ ์์. (์ ๊ทธ๋ผ ๋ฒกํฐ์ ์ฐจ์๋ ์ ๋ง๊ณ ์ฐ์ฐ ์์ฒด๊ฐ ๋ถ๊ฐ๋ฅํ๋ค)
์ ๊ทธ๋ฆผ์์๋ ์๋์ชฝ์์ ๊ณ ์ฐจ์ ๋ฒกํฐ์ ํด๋นํ๋ Vision Feature๋ฅผ ์ด๋ก์ ๊ณ์ด๋ก ํํํ๊ณ ์๋๋ฐ, ์ด์ ์ด Vision Feature๋ฅผ ์ ์ฐจ์์ผ๋ก ์ถ์ํด์ผ ํจ. (์ด๋์ ํ๊ฒ ๋ฒกํฐ(Latent)๋ฅผ ํ์์ผ๋ก ํํ)
Query๋ก ๋ค์ด๊ฐ๋ ์ ์ฐจ์ ๋ฒกํฐ๋ Learned Latent Vector(๊ทธ๋ฆผ์ ์์น์ ๋ณด๋ฅผ ์ ๊ฒฝ๋ง์ ์ ๋ ฅํด์ ๋์จ ๋ฒกํฐ)๋ฅผ ์ฌ์ฉํจ. ์ด๋ ๊ฒ Learned Latent Vector๋ Query๋ก, Vision Feature๋ฅผ Key, Value๋ก ํ์ฌ Cross Attention์ ์ํ

Gated XATTN-Dense๋ ์์ Vision ์ ๋ณด์ Language ์ ๋ณด๋ฅผ ์ตํฉ(Cross Attention)ํ์ฌ Language๊ฐ ๊ฐ๋ฏธ๋ ๋น์ ๋ฒกํฐ๋ฅผ ์ป๋๋ค๊ณ ์๊ฐํ๋ฉด ํธํจ. ์ฌ๊ธฐ์ ์ ๋ ฅ์ผ๋ก ์๊ตฌ๋๋ Query, Key, Value๋ฅผ ์ ํด์ผํ๋ค. ๋ณธ ๋ชจ๋ธ์ ๋น์ฃผ์ผ-์ธ์ด ๋ชจ๋ธ์ด๋ฏ๋ก ์ต์ข ์ถ๋ ฅ์ ํ ์คํธ์. ๋ฐ๋ผ์ Query ๋ก๋ Text Feature, ์ ๋ณด๋ฅผ ์ถ๊ฐํ Key, Value๋ก๋ Vision Feature๋ฅผ ์ฌ์ฉ
๋ณธ ๋ ผ๋ฌธ์ ์ฌ์ฉํ๋ ์ฐ์ฐ ์์๋ ์ ๊ทธ๋ฆผ์ ํตํด ํ์ธํ ์ ์์. ์ค๋ช ํ์๋ฉด ๋จผ์ visual feature(x)์ ๋ํด Cross Attention(q=y(language), kv=x) ์ฐ์ฐ ๋ค์ feed forward(FFW) ๋ ์ด์ด๋ฅผ ํตํด weight์ bias๋ฅผ (wx+b)๊ณ์ฐํ๊ณ , ์ดํ ๊ทธ ๊ฐ(y)์ ๋ค์ self attenction(q=y, kv=y)๋ฅผ ์ฐ์ฐ, FFW ๋ ์ด์ด๋ฅผ ๊ฑฐ์ณ ์ต์ข y๋ฅผ ์ถ๋ ฅํจ
๋ณธ ๋ ผ๋ฌธ์์ ์ฌ์ฉ๋ ๋ฐ์ดํฐ ์ ์ ์๋์ ๊ฐ์
์ดํ ์๋ ๋ ์กฐ๊ฑด์ผ๋ก ์คํ ์งํ
๋ณธ ๋ ผ๋ฌธ์ Few Shot Learning์ ๋ํ ๊ฒฐ๊ณผ๋ฅผ ์ํด ์ด 16๊ฐ์ ๋ฒค์น๋งํฌ๋ฅผ ์ฌ์ฉํ์ฌ ์ฑ๋ฅ์ ์ธก์ . ๊ฐ ๋ฌธ์ ๋ณ(๋ฒค์น๋งํฌ)๋ง๋ค ๋ชจ๋ธ์ Fine Tuning ํ์ง ์๊ณ ๋จ์ง ์ฌ์ ํ์ต์ ์๋ฃํ Flamingo์๊ฒ ๋ช ๊ฐ์ง ์์๋ฅผ ์ ๊ณตํ๊ณ ํด๊ฒฐํ๋๋ก ํจ

์คํ ๊ฒฐ๊ณผ ๊ธฐ์กด SOTA ๋ชจ๋ธ๋ค ๋๋น ์ฑ๋ฅ์ด ๊ฑฐ์ ๋ชจ๋ ์ข์๋ค๊ณ ํจ.

์ด๋ฒ์๋ ๋ฌธ์ ๋ค์ ๋ํด ์ถ๊ฐ ํ์ต์ ํ์ฌ(Fine-tune) ์ฑ๋ฅ์ ์ฐ์ถํ ๊ฒฐ๊ณผ, ๋๋ค์์ ๊ฒฐ๊ณผ์์ ์ข์ ์ฑ๋ฅ์ ๋ณด์
์ด ๋ ผ๋ฌธ์ ๊ฒฐ๋ก ์์ ์ ์๋ค์ Flamingo ๋ชจ๋ธ์ ์ ์ํ๋ฉฐ, ์ด๋ฏธ์ง ๋ฐ ๋น๋์ค ์์ ์ ์ต์ํ์ ์์ ๋ณ ํ๋ จ ๋ฐ์ดํฐ๋ง์ผ๋ก ์ ์ฉ ๊ฐ๋ฅํ ๋ฒ์ฉ ๋ชจ๋ธ์ด๋ผ๊ณ ์ค๋ช ํจ.
Flamingo๋ ์ ํต์ ์ธ ๋น์ ๋ฒค์น๋งํฌ๋ฅผ ๋์ด ๋ํ์ ๊ฐ์ ์ํธ์์ฉ ๊ธฐ๋ฅ์ ๋ณด์ฌ์ฃผ๋ฉฐ, ๋ค์ํ ์๊ฐ์ ์์ ์์ ๊ฐ๋ ฅํ ์ฑ๋ฅ์ ๋ฐํํ์๊ณ , ๋ํ, ์ฌ์ ํ์ต๋ ๋ํ ์ธ์ด ๋ชจ๋ธ๊ณผ ๊ฐ๋ ฅํ ๋น์ ๋ชจ๋ธ์ ์ฐ๊ฒฐํ๋ ๊ฒ์ด ๋ฒ์ฉ ์๊ฐ ์ดํด๋ก ๋์๊ฐ๋ ์ค์ํ ๋จ๊ณ์์ ๊ฐ์กฐํจ.
github๋ฅผ ์ฐธ๊ณ ํ์ฌ ์ค์น ๋ฐ ํ์ฉํ ์ ์์.
# grab model checkpoint from huggingface hub
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import torch
checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)
"""
Step 1: Load images
"""
demo_image_one = Image.open(
requests.get(
"http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
).raw
)
demo_image_two = Image.open(
requests.get(
"http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
stream=True
).raw
)
query_image = Image.open(
requests.get(
"http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
stream=True
).raw
)
"""
Step 2: Preprocessing images
Details: For OpenFlamingo, we expect the image to be a torch tensor of shape
batch_size x num_media x num_frames x channels x height x width.
In this case batch_size = 1, num_media = 3, num_frames = 1,
channels = 3, height = 224, width = 224.
"""
vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
"""
Step 3: Preprocessing text
Details: In the text we expect an <image> special token to indicate where an image is.
We also expect an <|endofchunk|> special token to indicate the end of the text
portion associated with an image.
"""
tokenizer.padding_side = "left" # For generation padding tokens should be on the left
lang_x = tokenizer(
["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"],
return_tensors="pt",
)
"""
Step 4: Generate text
"""
generated_text = model.generate(
vision_x=vision_x,
lang_x=lang_x["input_ids"],
attention_mask=lang_x["attention_mask"],
max_new_tokens=20,
num_beams=3,
)
print("Generated text: ", tokenizer.decode(generated_text[0]))