Flamingo: a Visual Language Model for Few-Shot Learning

ํ•œ์Šน์šฐ · September 16, 2024

paper-review

๋ชฉ๋ก ๋ณด๊ธฐ
3/4
post-thumbnail

Original Paper

Flamingo: a Visual Language Model for Few-Shot Learning

Advances in Neural Information Processing Systems (NeurIPS) 2022

Summary

Flamingo๐Ÿฆฉ๋ผ๋Š” ๋น„์ฃผ์–ผ ์–ธ์–ด ๋ชจ๋ธ์„ ์†Œ๊ฐœ
์ด ๋ชจ๋ธ์€ ์ฃผ์–ด์ง„ ์†Œ์ˆ˜์˜ ์˜ˆ์‹œ๋กœ ์ƒˆ๋กœ์šด ์ž‘์—…์„ ๋น ๋ฅด๊ฒŒ ํ•™์Šตํ•  ์ˆ˜ ์žˆ๋Š” ๋Šฅ๋ ฅ์„ ๊ฐ–์ถ”๊ณ  ์žˆ์œผ๋ฉฐ, ํ…์ŠคํŠธ์™€ ์ด๋ฏธ์ง€(๋น„๋””์˜ค)๋ฅผ ๊ฒฐํ•ฉํ•˜์—ฌ ์ƒˆ๋กœ์šด ํ…์ŠคํŠธ๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•™์Šตํ•จ. ์ด ๋ชจ๋ธ์€ ๋‹ค์–‘ํ•œ ๋ฒค์น˜๋งˆํฌ์—์„œ ํƒ์›”ํ•œ ์„ฑ๊ณผ๋ฅผ ๋ณด์˜€์œผ๋ฉฐ, ์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ ์—†์ด๋„ ๊ฐ„๋‹จํ•œ ์˜ˆ์‹œ๋“ค๋กœ ๋†’์€ ์„ฑ๋Šฅ์„ ๋‹ฌ์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ์ ์ด ์ฃผ์š” ๊ธฐ์—ฌ์ 

1. Introduction

Learning new tasks quickly is an active research area in computer vision, but most methods still rely on pretraining on large-scale data and then fine-tuning the model for the task of interest. Successful fine-tuning, however, typically requires thousands of annotated data points and per-task hyperparameter tuning, and it consumes a lot of resources.

์ตœ๊ทผ์—๋Š” ๋Œ€์กฐ์  ๋ชฉ์ (contrastive objective) ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•˜์—ฌ ํ•™์Šต๋œ ๋‹ค์ค‘ ๋ชจ๋‹ฌ ๋น„์ „-์–ธ์–ด ๋ชจ๋ธ์ด ๋ฏธ์„ธ ์กฐ์ • ์—†์ด๋„ ์ƒˆ๋กœ์šด ์ž‘์—…์— ์ œ๋กœ์ƒท ์ ์‘์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•จ. ํ•˜์ง€๋งŒ ์ด๋Ÿฌํ•œ ๋ชจ๋ธ์€ ํ…์ŠคํŠธ์™€ ์ด๋ฏธ์ง€ ๊ฐ„ ์œ ์‚ฌ์„ฑ ์ ์ˆ˜๋งŒ์„ ์ œ๊ณตํ•  ์ˆ˜ ์žˆ์–ด ๋ถ„๋ฅ˜์™€ ๊ฐ™์€ ์ œํ•œ๋œ ์‚ฌ์šฉ ์‚ฌ๋ก€์—๋งŒ ์ ์šฉ๋  ์ˆ˜ ์žˆ๋Š” ํ•œ๊ณ„๊ฐ€ ์žˆ์Œ.

์ด๋“ค์€ ์–ธ์–ด ์ƒ์„ฑ ๋Šฅ๋ ฅ์ด ๋ถ€์กฑํ•˜์—ฌ ์บก์…”๋‹์ด๋‚˜ ๋น„์ฃผ์–ผ ์งˆ๋ฌธ ์‘๋‹ต๊ณผ ๊ฐ™์€ ์ž‘์—…์—๋Š” ์ ํ•ฉํ•˜์ง€ ์•Š์Œ. ์ด๋ฅผ ๊ทน๋ณตํ•˜๊ธฐ ์œ„ํ•ด ์‹œ๊ฐ์ ์ธ ๋‚ด์šฉ์— ๋Œ€ํ•œ ์–ธ์–ด ์ƒ์„ฑ์— ๋Œ€ํ•œ ํƒ๊ตฌ๋„ ์ด๋ฃจ์–ด์กŒ์ง€๋งŒ, ์•„์ง ์ข‹์€ ์„ฑ๋Šฅ์„ ๋ณด์—ฌ์ฃผ์ง€ ๋ชปํ•˜๊ณ  ์žˆ์Œ.

๋ณธ ๋…ผ๋ฌธ์€ ์ด๋Ÿฌํ•œ ํ•œ๊ณ„๋ฅผ ๊ทน๋ณตํ•˜๊ณ , ๋‹ค์–‘ํ•œ ๋น„์ „ ๋ฐ ์–ธ์–ด ์ž‘์—…์—์„œ ์†Œ์ˆ˜์˜ ์˜ˆ์‹œ๋งŒ์œผ๋กœ ํ•™์Šตํ•  ์ˆ˜ ์žˆ๋Š” ๋น„์ฃผ์–ผ ์–ธ์–ด ๋ชจ๋ธ์ธ Flamingo๋ฅผ ์†Œ๊ฐœํ•จ. ์ด ๋ชจ๋ธ์€ ๋ช‡ ๊ฐ€์ง€ ์ž…๋ ฅ/์ถœ๋ ฅ ์˜ˆ์‹œ๋งŒ์œผ๋กœ ๋‹ค์–‘ํ•œ ์ž‘์—…์— ์ ์šฉ๋  ์ˆ˜ ์žˆ์œผ๋ฉฐ, ๋ฏธ์„ธ ์กฐ์ • ์—†์ด๋„ ์—ฌ๋Ÿฌ ์ž‘์—…์—์„œ ์ƒˆ๋กœ์šด ์„ฑ๋Šฅ์„ ๋ณด์—ฌ์คŒ.

Example Flamingo outputs

(Answer about which animal is shown)

(Answer about how many animals there are)

(Description of the video's content)

2. ์ƒ์„ธ์„ค๋ช…

As shown in the figure above, Flamingo consists of a Vision encoder and Language model (LM) blocks; its distinctive components are the Perceiver Resampler (light purple in the figure) and the Gated XATTN-Dense layers.

Vision encoder : a pretrained vision encoder trained with CLIP-style text-image contrastive learning, loaded and kept frozen

Language model : the Chinchilla language model, pretrained on a large text corpus, loaded and kept frozen

Perceiver Resampler : maps the vision encoder's output to a fixed number of visual tokens (the fixed, small size keeps the downstream computation efficient)

Gated XATTN-Dense : a cross-attention layer that takes a Query, Key, and Value as input and outputs vectors infused with the attended information

The Perceiver Resampler and Gated XATTN-Dense are described next.

Perceiver Resampler

First of all, the Perceiver Resampler is not something newly proposed in this paper (2022); it builds on Perceiver: General Perception with Iterative Attention (2021). Here it can be viewed as a resampling method that prepares the visual features for the cross-attention layers.

Before going into the details, recall that Flamingo has to process the Vision Features and the Text Features together.

A vision encoder produces a large (and variable) number of feature vectors, far more than the number of text tokens, so to combine the two, the visual features must first be compressed into a small, fixed number of tokens. (Otherwise the shapes do not line up and the cross-attention computation becomes prohibitively expensive or simply impossible.)

In the figure above, the Vision Features coming out of the vision encoder are drawn at the bottom in green; these are what need to be compressed. (The target latent vectors are drawn in gray.)

The vectors fed in as the Query are a small set of Learned Latent Vectors (learnable query vectors trained together with the rest of the model, rather than features computed from the input). These learned latents serve as the Query, while the Vision Features serve as the Key and Value, and cross attention is computed between them.
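
A minimal PyTorch sketch of this latent cross attention is shown below; the dimensions, layer count, and module names are illustrative assumptions, not the actual Flamingo implementation.

import torch
import torch.nn as nn

class LatentResampler(nn.Module):
    """Compress a variable number of visual features into a fixed set of tokens."""
    def __init__(self, dim=512, num_latents=64, num_heads=8):
        super().__init__()
        # Learned latent queries: trained end-to-end, shared across all inputs
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ffw = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, vision_feats):                         # vision_feats: (B, N, dim), N can vary
        b = vision_feats.size(0)
        q = self.latents.unsqueeze(0).expand(b, -1, -1)      # latents as Query: (B, 64, dim)
        attn_out, _ = self.cross_attn(q, vision_feats, vision_feats)  # vision feats as Key/Value
        x = q + attn_out                                      # residual connection
        return x + self.ffw(x)                                # fixed-size output: (B, 64, dim)

# Usage: 2 images with 256 patch features each -> 64 visual tokens each
resampler = LatentResampler()
tokens = resampler(torch.randn(2, 256, 512))
print(tokens.shape)                                           # torch.Size([2, 64, 512])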

Gated XATTN-Dense

Gated XATTN-Dense is where the Vision and Language information are fused via cross attention, producing language representations enriched with visual information. The layer takes a Query, Key, and Value as input, so we need to decide what plays each role. Since Flamingo is a visual language model whose final output is text, the Text Features are used as the Query, and the Vision Features, the information being injected, are used as the Key and Value.

๋ณธ ๋…ผ๋ฌธ์— ์‚ฌ์šฉํ•˜๋Š” ์—ฐ์‚ฐ ์ˆœ์„œ๋Š” ์œ„ ๊ทธ๋ฆผ์„ ํ†ตํ•ด ํ™•์ธํ•  ์ˆ˜ ์žˆ์Œ. ์„ค๋ช…ํ•˜์ž๋ฉด ๋จผ์ € visual feature(x)์— ๋Œ€ํ•ด Cross Attention(q=y(language), kv=x) ์—ฐ์‚ฐ ๋’ค์— feed forward(FFW) ๋ ˆ์ด์–ด๋ฅผ ํ†ตํ•ด weight์™€ bias๋ฅผ (wx+b)๊ณ„์‚ฐํ•˜๊ณ , ์ดํ›„ ๊ทธ ๊ฐ’(y)์— ๋‹ค์‹œ self attenction(q=y, kv=y)๋ฅผ ์—ฐ์‚ฐ, FFW ๋ ˆ์ด์–ด๋ฅผ ๊ฑฐ์ณ ์ตœ์ข… y๋ฅผ ์ถœ๋ ฅํ•จ

3. Experimental Results

๋ณธ ๋…ผ๋ฌธ์—์„œ ์‚ฌ์šฉ๋œ ๋ฐ์ดํ„ฐ ์…‹์€ ์•„๋ž˜์™€ ๊ฐ™์Œ

  • M3W (MultiModal MassiveWeb): images and text collected from 43 million webpages; interleaved visual data is extracted based on the relative positions of images and text within each page
  • ALIGN: 1.8 billion images paired with alt-text
  • LTIP (Long Text & Image Pairs): 312 million images paired with longer descriptive texts
  • VTP (Video & Text Pairs): 27 million short videos (about 22 seconds long on average) paired with sentence-level descriptions

Experiments are then carried out under the following two settings.

1. Few-shot learning on vision-language tasks

๋ณธ ๋…ผ๋ฌธ์€ Few Shot Learning์— ๋Œ€ํ•œ ๊ฒฐ๊ณผ๋ฅผ ์œ„ํ•ด ์ด 16๊ฐœ์˜ ๋ฒค์น˜๋งˆํฌ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์„ฑ๋Šฅ์„ ์ธก์ •. ๊ฐ ๋ฌธ์ œ๋ณ„(๋ฒค์น˜๋งˆํฌ)๋งˆ๋‹ค ๋ชจ๋ธ์„ Fine Tuning ํ•˜์ง€ ์•Š๊ณ  ๋‹จ์ง€ ์‚ฌ์ „ ํ•™์Šต์„ ์™„๋ฃŒํ•œ Flamingo์—๊ฒŒ ๋ช‡ ๊ฐ€์ง€ ์˜ˆ์‹œ๋ฅผ ์ œ๊ณตํ•˜๊ณ  ํ•ด๊ฒฐํ•˜๋„๋ก ํ•จ

According to the results, Flamingo outperformed the previous state-of-the-art few-shot approaches on nearly all of these benchmarks.

2. Fine-tuning Flamingo as a pretrained vision-language model

์ด๋ฒˆ์—๋Š” ๋ฌธ์ œ๋“ค์— ๋Œ€ํ•ด ์ถ”๊ฐ€ ํ•™์Šต์„ ํ•˜์—ฌ(Fine-tune) ์„ฑ๋Šฅ์„ ์‚ฐ์ถœํ•œ ๊ฒฐ๊ณผ, ๋Œ€๋‹ค์ˆ˜์˜ ๊ฒฐ๊ณผ์—์„œ ์ข‹์€ ์„ฑ๋Šฅ์„ ๋ณด์ž„

4. Conclusion

์ด ๋…ผ๋ฌธ์˜ ๊ฒฐ๋ก ์—์„œ ์ €์ž๋“ค์€ Flamingo ๋ชจ๋ธ์„ ์ œ์•ˆํ•˜๋ฉฐ, ์ด๋ฏธ์ง€ ๋ฐ ๋น„๋””์˜ค ์ž‘์—…์— ์ตœ์†Œํ•œ์˜ ์ž‘์—…๋ณ„ ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ๋งŒ์œผ๋กœ ์ ์šฉ ๊ฐ€๋Šฅํ•œ ๋ฒ”์šฉ ๋ชจ๋ธ์ด๋ผ๊ณ  ์„ค๋ช…ํ•จ.

Flamingo๋Š” ์ „ํ†ต์ ์ธ ๋น„์ „ ๋ฒค์น˜๋งˆํฌ๋ฅผ ๋„˜์–ด ๋Œ€ํ™”์™€ ๊ฐ™์€ ์ƒํ˜ธ์ž‘์šฉ ๊ธฐ๋Šฅ์„ ๋ณด์—ฌ์ฃผ๋ฉฐ, ๋‹ค์–‘ํ•œ ์‹œ๊ฐ์  ์ž‘์—…์—์„œ ๊ฐ•๋ ฅํ•œ ์„ฑ๋Šฅ์„ ๋ฐœํœ˜ํ•˜์˜€๊ณ , ๋˜ํ•œ, ์‚ฌ์ „ ํ•™์Šต๋œ ๋Œ€ํ˜• ์–ธ์–ด ๋ชจ๋ธ๊ณผ ๊ฐ•๋ ฅํ•œ ๋น„์ „ ๋ชจ๋ธ์„ ์—ฐ๊ฒฐํ•˜๋Š” ๊ฒƒ์ด ๋ฒ”์šฉ ์‹œ๊ฐ ์ดํ•ด๋กœ ๋‚˜์•„๊ฐ€๋Š” ์ค‘์š”ํ•œ ๋‹จ๊ณ„์ž„์„ ๊ฐ•์กฐํ•จ.

Usage

The model can be installed and used by following the OpenFlamingo GitHub repository (an open-source implementation of Flamingo, which the example below is based on).
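
Assuming the open_flamingo package published by the OpenFlamingo project, installation is a single pip command:

pip install open-flamingo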

from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import torch

from open_flamingo import create_model_and_transforms

# Build the model, image processor, and tokenizer used below
# (configuration follows the OpenFlamingo README for the OpenFlamingo-3B-vitl-mpt1b checkpoint)
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
)

# grab model checkpoint from huggingface hub
checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)

"""
Step 1: Load images
"""
demo_image_one = Image.open(
    requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    ).raw
)

demo_image_two = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
        stream=True
    ).raw
)

query_image = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028352.jpg", 
        stream=True
    ).raw
)


"""
Step 2: Preprocessing images
Details: For OpenFlamingo, we expect the image to be a torch tensor of shape 
 batch_size x num_media x num_frames x channels x height x width. 
 In this case batch_size = 1, num_media = 3, num_frames = 1,
 channels = 3, height = 224, width = 224.
"""
vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)

"""
Step 3: Preprocessing text
Details: In the text we expect an <image> special token to indicate where an image is.
 We also expect an <|endofchunk|> special token to indicate the end of the text 
 portion associated with an image.
"""
tokenizer.padding_side = "left" # For generation padding tokens should be on the left
lang_x = tokenizer(
    ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"],
    return_tensors="pt",
)


"""
Step 4: Generate text
"""
generated_text = model.generate(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=20,
    num_beams=3,
)

print("Generated text: ", tokenizer.decode(generated_text[0]))