๐Ÿ‘ฉโ€๐ŸŽจ GPU Parallel Reduction์„ AI ๊ด€์ ์—์„œ ์ดํ•ดํ•˜๊ธฐ

๊น€์ˆ˜์ง„ยท2026๋…„ 5์›” 11์ผ

AI ์—ฐ์‚ฐ์—์„œ์˜ Reduction

Reduction์€ ๋‹จ์ˆœ ๋ฐฐ์—ด ํ•ฉ ๊ณ„์‚ฐ์ด ์•„๋‹ˆ๋ผ, ์‹ค์ œ AI ๋ชจ๋ธ ๋‚ด๋ถ€์—์„œ๋„ ๋งค์šฐ ์ž์ฃผ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.

์˜ˆ๋ฅผ ๋“ค์–ด:

  • Loss ๊ณ„์‚ฐ
  • Softmax
  • Attention score ํ•ฉ์‚ฐ
  • Gradient accumulation
  • Batch normalization ํ‰๊ท  ๊ณ„์‚ฐ
  • Matrix multiplication์˜ partial sum

๋“ฑ ๋Œ€๋ถ€๋ถ„์˜ ๋”ฅ๋Ÿฌ๋‹ ์—ฐ์‚ฐ ๋‚ด๋ถ€์—๋Š” Reduction์ด ํฌํ•จ๋ฉ๋‹ˆ๋‹ค.

ํŠนํžˆ ๋”ฅ๋Ÿฌ๋‹ ํ•™์Šต์—์„œ๋Š” ์ˆ˜์ฒœ~์ˆ˜๋งŒ ๊ฐœ์˜ ๊ฐ’๋“ค์„ ํ•ฉ์น˜๊ฑฐ๋‚˜ ํ‰๊ท ๋‚ด๋Š” ๊ณผ์ •์ด ๋ฐ˜๋ณต๋˜๋Š”๋ฐ, ์ด ์—ฐ์‚ฐ์ด ๋А๋ฆฌ๋ฉด ์ „์ฒด ํ•™์Šต ์†๋„๋„ ํฌ๊ฒŒ ๊ฐ์†Œํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

์˜ค๋Š˜์€ GPU์—์„œ AI ์—ฐ์‚ฐ์ด ์–ด๋–ป๊ฒŒ ๋ณ‘๋ ฌ์ฒ˜๋ฆฌ ๋˜๋Š”์ง€ ์‚ดํŽด๋ณด๊ณ ์ž ๋ธ”๋กœ๊ทธ๊ธ€์„ ์ž‘์„ฑํ•˜๊ฒŒ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

Reduction์ด๋ž€ ?

Reduction์€ ์—ฌ๋Ÿฌ ๊ฐœ์˜ ๊ฐ’์„ ํ•˜๋‚˜์˜ ๊ฐ’์œผ๋กœ ์ค„์ด๋Š” ์—ฐ์‚ฐ์ž…๋‹ˆ๋‹ค.

์˜ˆ๋ฅผ ๋“ค์–ด ๋ฐฐ์—ด ์•ˆ์— ์ˆซ์ž๊ฐ€ ์—ฌ๋Ÿฌ ๊ฐœ ์žˆ๋‹ค๊ณ  ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
[7.0, 2.1, 5.3, 9.0, 11.2]

์ด ์ˆซ์ž๋“ค์„ ๋ชจ๋‘ ๋”ํ•˜๋ฉด ์ตœ์ข…์ ์œผ๋กœ 34.6์ด๋ผ๋Š” ํ•˜๋‚˜์˜ ๊ฐ’์ด ๋‚˜์˜ต๋‹ˆ๋‹ค.
7.0 + 2.1 + 5.3 + 9.0 + 11.2 = 34.6

์ด์ฒ˜๋Ÿผ ์—ฌ๋Ÿฌ ๊ฐœ์˜ ๋ฐ์ดํ„ฐ๋ฅผ ํ•˜๋‚˜์˜ ๊ฒฐ๊ณผ๊ฐ’์œผ๋กœ ๋งŒ๋“œ๋Š” ๊ณผ์ •์„ Reduction์ด๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค.

Reduction์€ ๋ง์…ˆ์—๋งŒ ์‚ฌ์šฉ๋˜๋Š” ๊ฐœ๋…์€ ์•„๋‹ˆ๊ณ , ๋งŽ์€ ์ˆซ์ž๋“ค์„ ์ค„์ด๋Š” ๊ณผ์ • ๋ชจ๋‘ Reduction์ž…๋‹ˆ๋‹ค. ๋ฐฐ์—ด์—์„œ ๊ฐ€์žฅ ํฐ ๊ฐ’์„ ์ฐพ๋Š” ๊ฒƒ๋„ Reduction์ด๊ณ , ๊ฐ€์žฅ ์ž‘์€ ๊ฐ’์„ ์ฐพ๋Š” ๊ฒƒ๋„ Reduction์ž…๋‹ˆ๋‹ค.

์˜ˆ๋ฅผ ๋“ค์–ด, [7, 2, 5, 9]
์ด ๋ฐฐ์—ด์—์„œ ์ตœ๋Œ“๊ฐ’์„ ์ฐพ์œผ๋ฉด ๊ฒฐ๊ณผ๋Š” 9์ž…๋‹ˆ๋‹ค.
๊ฒฐ๊ตญ ์—ฌ๋Ÿฌ ๊ฐ’ ์ค‘์—์„œ ํ•˜๋‚˜์˜ ์ตœ๋Œ“๊ฐ’๋งŒ ๋‚จ์•˜๊ธฐ ๋•Œ๋ฌธ์— ์ด๊ฒƒ๋„ Reduction์ž…๋‹ˆ๋‹ค.

์ฆ‰, Reduction์€ ์‰ฝ๊ฒŒ ๋งํ•ด, ๋งŽ์€ ๋ฐ์ดํ„ฐ ์ค‘์—์„œ ํ•˜๋‚˜์˜ ๋Œ€ํ‘œ๊ฐ’์„ ๋ฝ‘์•„๋‚ด๋Š” ๊ณผ์ •์ด๋ผ๊ณ  ์ดํ•ดํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

์—ฌ๊ธฐ์„œ ์ค‘์š”ํ•œ ์ ์€ Reduction์ด ๋‘ ๊ฐœ์˜ ๊ฐ’์„ ๋น„๊ตํ•˜๊ฑฐ๋‚˜ ๊ณ„์‚ฐํ•˜๋Š” ์—ฐ์‚ฐ์„ ๋ฐ˜๋ณตํ•ด์„œ ์ „์ฒด ๊ฒฐ๊ณผ๋ฅผ ๋งŒ๋“ ๋‹ค๋Š” ๊ฒƒ ์ž…๋‹ˆ๋‹ค.
์˜ˆ๋ฅผ ๋“ค์–ด ํ•ฉ๊ณ„๋ฅผ ๊ตฌํ•  ๋•Œ๋Š” ๋‘ ๊ฐ’์„ ๊ณ„์† ๋”ํ•˜๊ณ , ์ตœ๋Œ“๊ฐ’์„ ๊ตฌํ•  ๋•Œ๋Š” ๋‘ ๊ฐ’ ์ค‘ ๋” ํฐ ๊ฐ’์„ ๊ณ„์† ์„ ํƒํ•ฉ๋‹ˆ๋‹ค.

Reduction์—์„œ์˜ ๋ณ‘๋ ฌ

  • Parallel Reduction = ๋ณ‘๋ ฌ๋กœ reduction์„ ์ˆ˜ํ–‰ํ•˜๋Š” ์ „์ฒด ๊ฐœ๋…
  • Reduction Tree = ๊ทธ ๋ณ‘๋ ฌ reduction์„ ํŠธ๋ฆฌ ๊ตฌ์กฐ๋กœ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ฐฉ๋ฒ•

์ˆœ์ฐจ ๋ฐฉ์‹์—์„œ๋Š” 8๊ฐœ์˜ ๊ฐ’์„ ์ฒ˜๋ฆฌํ•˜๋ ค๋ฉด ๊ฑฐ์˜ 8๋ฒˆ์— ๊ฐ€๊นŒ์šด ๋‹จ๊ณ„๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.

  • 3์ด๋ž‘ 1 ๋น„๊ต โ†’ 3
  • 3์ด๋ž‘ 7 ๋น„๊ต โ†’ 7
  • 7์ด๋ž‘ 0 ๋น„๊ต โ†’ 7
  • 7์ด๋ž‘ 4 ๋น„๊ต โ†’ 7

reduction tree์—์„œ๋Š” 3๋‹จ๊ณ„๋ฉด ๋๋‚ฉ๋‹ˆ๋‹ค.

  • 8๊ฐœ ๋ฐ์ดํ„ฐ โ†’ logโ‚‚8 = 3๋‹จ๊ณ„
  • 1024๊ฐœ ๋ฐ์ดํ„ฐ โ†’ logโ‚‚1024 = 10๋‹จ๊ณ„

์ฆ‰, ์‹œ๊ฐ„๋ณต์žก๋„๋Š” O(N) -> O(log N)์œผ๋กœ ์ค„์–ด๋“ ๋‹ค๋Š” ์žฅ์ ์„ ๊ฐ€์ง€๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

Reduction Tree ๊ธฐ๋ฐ˜์˜ Parallel Reduction ๊ณผ์ •

๋จผ์ € ๋งจ ์œ„๋ฅผ ๋ณด๋ฉด Thread 0๋ถ€ํ„ฐ Thread 7๊นŒ์ง€ ์ด 8๊ฐœ์˜ thread๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.
๊ทธ๋ฆฌ๊ณ  ์•„๋ž˜์˜ 0~15 ์ˆซ์ž๋Š” input array์˜ index๋ฅผ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค. ์ฆ‰, ์ด 16๊ฐœ์˜ ๋ฐ์ดํ„ฐ๋ฅผ reductionํ•˜๋Š” ์ƒํ™ฉ์ž…๋‹ˆ๋‹ค.
๊ฐ thread๋Š” ์ง์ˆ˜ index ํ•˜๋‚˜๋ฅผ ๋‹ด๋‹นํ•ฉ๋‹ˆ๋‹ค.

์˜ˆ๋ฅผ ๋“ค์–ด:

Thread 0 โ†’ input[0]
Thread 1 โ†’ input[2]
Thread 2 โ†’ input[4]
...
Thread 7 โ†’ input[14]

์ด์ฒ˜๋Ÿผ ๊ฐ thread๋Š” ์ž๊ธฐ ์œ„์น˜๋ฅผ ๋‹ด๋‹นํ•˜๊ณ , ์ž๊ธฐ ์œ„์น˜๋งŒ ์ˆ˜์ •ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ owner computes ๋ฐฉ์‹์ด๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค.

์ด ๊ทธ๋ฆผ์—์„œ ํ•ต์‹ฌ์€ ์„ธ ๊ฐ€์ง€์ธ๋ฐ์š”.

  1. ์ฒซ ๋ฒˆ์งธ๋กœ stride ๊ฐ’์ด 1 โ†’ 2 โ†’ 4 โ†’ 8์ฒ˜๋Ÿผ ๊ณ„์† 2๋ฐฐ์”ฉ ์ฆ๊ฐ€ํ•œ๋‹ค๋Š” ์ ์ž…๋‹ˆ๋‹ค.
  2. ๋‘ ๋ฒˆ์งธ๋กœ active thread ์ˆ˜๋Š” 8 โ†’ 4 โ†’ 2 โ†’ 1์ฒ˜๋Ÿผ ์ ˆ๋ฐ˜์”ฉ ๊ฐ์†Œํ•ฉ๋‹ˆ๋‹ค.
  3. ์„ธ ๋ฒˆ์งธ๋กœ partial sum์ด ์ ์  ๋” ํฐ ๋ฒ”์œ„๋ฅผ ํฌํ•จํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

control divergence

์—ฌ๊ธฐ์— ๋ฌธ์ œ์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค.
์•ž์—์„œ ๋ณธ ๊ทธ๋ฆผ์—์„œ๋Š” stride๊ฐ€ ์ด๋ ‡๊ฒŒ ์ฆ๊ฐ€ํ–ˆ์Šต๋‹ˆ๋‹ค.

stride: 1 โ†’ 2 โ†’ 4 โ†’ 8

๊ทธ๋Ÿฌ๋ฉด active thread๋Š” ์ด๋ ‡๊ฒŒ ๋–จ์–ด์ ธ ์žˆ๊ฒŒ ๋˜๊ฒ ์ฃ .

1๋‹จ๊ณ„: Thread 0, 1, 2, 3, 4, 5, 6, 7
2๋‹จ๊ณ„: Thread 0, 2, 4, 6
3๋‹จ๊ณ„: Thread 0, 4
4๋‹จ๊ณ„: Thread 0

๋ฌธ์ œ๋Š” GPU์—์„œ thread๋“ค์ด ๋ณดํ†ต warp๋ผ๋Š” ๋‹จ์œ„๋กœ ๋ฌถ์—ฌ ์‹คํ–‰๋œ๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋ณดํ†ต ํ•œ warp๋Š” 32๊ฐœ์˜ thread๋กœ ๊ตฌ์„ฑ๋ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ๋ฐ ํ•œ warp ์•ˆ์—์„œ ์–ด๋–ค thread๋Š” ๊ณ„์‚ฐํ•˜๊ณ , ์–ด๋–ค thread๋Š” ๊ณ„์‚ฐํ•˜์ง€ ์•Š์œผ๋ฉด ๋ฌธ์ œ๊ฐ€ ์ƒ๊น๋‹ˆ๋‹ค.

์‹ค์ œ๋กœ ๋…ผ๋ฌธ์—์„œ๋Š” input size๊ฐ€ 256์ผ ๋•Œ ๊ธฐ์กด ๋ฐฉ์‹์˜ ์ž์› ํ™œ์šฉ ํšจ์œจ์ด ์•ฝ 35%๋ฐ–์— ์•ˆ ๋œ๋‹ค๊ณ  ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค.

๋”ฐ๋ผ์„œ ๊ฐœ์„ ๋œ ๋ฐฉ์‹์ด ๋‚˜์˜ค๋Š”๋ฐ์š”.
์•„์ด๋””์–ด๋Š” active thread๋“ค์„ ๋ฉ€๋ฆฌ ๋–จ์–ด๋œจ๋ฆฌ์ง€ ๋ง๊ณ , ์•ž์ชฝ์— ์—ฐ์†์ ์œผ๋กœ ๋ชจ์•„๋‘์ž.์ž…๋‹ˆ๋‹ค.

์ฆ‰ ๊ธฐ์กด ๋ฐฉ์‹์€ active thread๊ฐ€ ์ด๋ ‡๊ฒŒ ๋๋‹ค๋ฉด:

  • Thread 0, 2, 4, 6

๊ฐœ์„  ๋ฐฉ์‹์€ ์ด๋ ‡๊ฒŒ ๋˜๊ฒŒ ๋งŒ๋“œ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

  • Thread 0, 1, 2, 3

๊ทธ๋Ÿฌ๋ฉด warp ๋‹จ์œ„๋กœ ๋ดค์„ ๋•Œ, ์–ด๋–ค warp๋Š” ์ „๋ถ€ ์ผํ•˜๊ณ  ์–ด๋–ค warp๋Š” ์ „๋ถ€ ์‰ฌ๊ฒŒ ๋งŒ๋“ค ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌ๋ฉด ํ•œ warp ์•ˆ์—์„œ ๊ณ„์‚ฐํ•˜๋Š” thread์™€ ์‰ฌ๋Š” thread๊ฐ€ ์„ž์ด๋Š” ์ƒํ™ฉ์ด ์ค„์–ด๋“ญ๋‹ˆ๋‹ค.

memory ์ ‘๊ทผ ํšจ์œจ

global memory๋ฅผ ์ž์ฃผ ์“ฐ์ง€ ์•Š๊ณ , shared memory๋ฅผ ์ž์ฃผ ์‚ฌ์šฉํ•˜์—ฌ, memory ์ ‘๊ทผ ํšจ์œจ์„ ๋Š˜๋ ค๋ณด์ž๋Š” ์•„์ด๋””์–ด๋ฅผ ๋– ์˜ฌ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ฒ˜์Œ์—” global memory๋ฅผ ํ†ตํ•ด์„œ input ๊ฐ’์„ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค. ๊ทธ ์ดํ›„ ๊ณ„์‚ฐ ๋ถ€ํ„ฐ๋Š” global memory๊ฐ€ ์•„๋‹ˆ๋ผ shared memory์— ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. ์ฆ‰, ๋‹ค์Œ ๊ณผ์ • reduction ๋ถ€ํ„ฐ๋Š” shared memory ์•ˆ์—์„œ๋งŒ ์ง„ํ–‰ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

๋งˆ์ง€๋ง‰์œผ๋กœ ์ตœ์ข… ๊ฒฐ๊ณผ๋งŒ global memory์˜ output์— ํ•œ ๋ฒˆ ์ €์žฅํ•˜๋ฉด memory ์ ‘๊ทผ ํšจ์œจ์„ฑ์„ ์˜ฌ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋งˆ๋ฌด๋ฆฌ

Reduction์€ ๋‹จ์ˆœํ•œ ๋ฐฐ์—ด ํ•ฉ์‚ฐ ์•Œ๊ณ ๋ฆฌ์ฆ˜์ฒ˜๋Ÿผ ๋ณด์ด์ง€๋งŒ, ์‹ค์ œ๋กœ๋Š” AI ์—ฐ์‚ฐ๊ณผ GPU ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ์˜ ํ•ต์‹ฌ์ด ๋˜๋Š” ํŒจํ„ด์ž…๋‹ˆ๋‹ค.

ํŠนํžˆ:

  • control divergence ์ตœ์†Œํ™”
  • memory ์ ‘๊ทผ ์ตœ์ ํ™”
  • warp ๋‹จ์œ„ ์‹คํ–‰ ๊ตฌ์กฐ ์ดํ•ด

๊ฐ™์€ ๊ฐœ๋…๋“ค์€ CUDA ์ตœ์ ํ™”์—์„œ ๋งค์šฐ ์ค‘์š”ํ•œ ์š”์†Œ์ž…๋‹ˆ๋‹ค.

๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์ด ์ปค์งˆ์ˆ˜๋ก ์—ฐ์‚ฐ๋Ÿ‰๋„ ํญ๋ฐœ์ ์œผ๋กœ ์ฆ๊ฐ€ํ•˜๊ธฐ ๋•Œ๋ฌธ์—, ์•ž์œผ๋กœ์˜ AI ์‹œ์Šคํ…œ์—์„œ๋Š” ๋‹จ์ˆœ ๋ชจ๋ธ ๊ตฌ์กฐ๋ฟ ์•„๋‹ˆ๋ผ GPU ์—ฐ์‚ฐ ์ตœ์ ํ™” ์—ญ์‹œ ๋”์šฑ ์ค‘์š”ํ•ด์งˆ ๊ฒƒ์ด๋ผ๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค.

0๊ฐœ์˜ ๋Œ“๊ธ€