[Lucid] 텐서 유틸리티 연산들의 구현

안암동컴맹·2025년 12월 10일

Lucid Development

목록 보기
5/20
post-thumbnail

🧭 텐서 유틸리티 연산들의 구현

🧱 추상화 맥락

Lucid의 연산 계층을 엄격히 나눌 때, 최하위 primitive(add, matmul 등) 위에는 형태 변환·축 조작·집계 연산이 있다. 이들은 lucid._util.func에 모여 있고, 실제로는 reshape, squeeze, stack, pad, repeat, tile, flatten, broadcast_to, where, sort/topk 같은 함수들이다. 이 층의 목표는 NumPy 호출을 이 레이어에서만 소비하면서도, gradient가 정확히 흐르도록 축/shape 정보를 치밀하게 다루는 것이다. 아래에서는 자주 쓰이는 연산들을 골라 forward 정의와 gradient 처리, 그리고 axis/keepdims 규약을 어떻게 맞췄는지 정리한다.


🔄 reshape/squeeze/unsqueeze: shape 왕복 보장

  • 경로: lucid/_util/func.py
  • 핵심 포인트: forward에서 shape를 변경하면 backward에서 원 shape로 정확히 되돌리는 reshape가 필요하다.
class reshape(operation):
    @unary_func_op()
    def cpu(self, a: Tensor):
        self.result = Tensor(a.data.reshape(self.shape))
        return self.result, partial(self.__grad__, a=a)

    def __grad__(self, a: Tensor):
        return self.result.grad.reshape(*a.shape)

squeezeunsqueeze도 동일한 패턴을 따른다. squeeze는 특정 axis를 제거하고 backward에서 reshape(a.shape)로 원복, unsqueezeexpand_dims 후 backward에서 squeeze로 축소한다. 규칙: shape 정보는 forward에서 캡처하고, backward에서는 단일 reshape으로 복원한다. 추가 로직이나 조건 분기는 모두 forward에만 둔다.

📦 stack/concatenate: 분리-병합의 쌍대성

  • 경로: stack, hstack, vstack, concatenate (lucid/_util/func.py)
  • forward: 여러 Tensor를 특정 axis로 합친다.
  • backward: 합쳐진 gradient를 입력 개수만큼 split 해서 되돌린다.
class stack(operation):
    def __grad__(self, arr: tuple[Tensor], lib_: ModuleType):
        split_grads = lib_.split(self.result.grad, len(arr), axis=self.axis)
        return tuple(split_grads)

class concatenate(operation):
    def __grad__(self, arr: tuple[Tensor, ...]):
        split_sizes = [a.shape[self.axis] for a in arr]
        grad = self.result.grad
        outputs = []
        start = 0
        for size in split_sizes:
            slicer = [slice(None)] * grad.ndim
            slicer[self.axis] = slice(start, start + size)
            outputs.append(grad[tuple(slicer)])
            start += size
        return tuple(outputs)

축 규약: 모든 입력의 ndim과 axis가 일치한다고 가정하며, 브로드캐스트는 하지 않는다. gradient는 forward에서 쌓인 순서를 그대로 split 하여 돌려준다.

🧱 pad: 패딩 구간 잘라내기

  • 경로: lucid/_util/func.py
  • forward: np.pad/mx.pad 호출 전 pad_width를 (before, after) 튜플 리스트로 정규화한다.
  • backward: 패딩을 제거한 slice만 남긴다.
def __grad__(self, a: Tensor, lib_: ModuleType):
    grad_input = lib_.zeros_like(a.data)
    slices = []
    for before, after in self.pad_with_norm:
        start = before
        end = -after if after != 0 else None
        slices.append(slice(start, end))
    grad_input = self.result.grad[tuple(slices)]
    return grad_input

규칙: forward에서 확장한 영역을 backward에서는 버린다. pad_width 정규화가 핵심이므로, 단일 int·길이 2 튜플·축별 튜플 모두 (before, after) 리스트로 변환해 재사용한다.

🔁 repeat/tile: 축별 확장과 축소

  • 경로: repeat, tile (lucid/_util/func.py)
  • forward: 특정 axis 또는 전체(flat)에서 요소를 반복.
  • backward: 반복된 위치의 gradient를 합산해 원래 위치로 축소.

repeat의 backward는 axis가 없는 경우(flat)와 특정 axis인 경우를 나눠 처리한다. 핵심은 output 인덱스 → input 인덱스 매핑을 만들고, 거기에 grad를 accumulate 하는 것.

def __grad__(self, a: Tensor, lib_: ModuleType):
    if self.axis is None:
        output_indices = np.repeat(np.arange(input_size), repeats_arr)
        np.add.at(grad_input_flat, output_indices, grad_output_flat)
        ...
    else:
        output_indices_axis = np.repeat(input_indices_axis, repeats_arr, axis=axis_)
        idx = np.stack(np.meshgrid(..., indexing="ij"))
        idx[axis_] = output_indices_axis
        np.add.at(grad_input, tuple(idx), self.result.grad)
        ...

tile은 repeat과 유사하지만 reps 배열을 shape 앞쪽에 끼워넣어 reshape 후 짝수 축에 대해 sum(axis=axes_to_sum)을 수행한다. 원리: forward에서 확장한 차원 수만큼 backward에서 sum으로 축소한다.

🧮 flatten: 구간 합치기

  • 경로: lucid/_util/func.py
  • forward: [start_axis, end_axis] 구간을 하나의 축으로 곱해 합친다.
  • backward: 저장한 original_shape로 reshape.
flat_axis = 1
for i in range(start, end + 1):
    flat_axis *= a.shape[i]
new_shape = a.shape[:start] + (flat_axis,) + a.shape[end + 1 :]
self.result = Tensor(a.data.reshape(new_shape))

def __grad__(self):
    return self.result.grad.reshape(self.original_shape)

축 규약: 음수 axis도 허용해 start/end를 실제 인덱스로 변환한다. 곱셈 순서는 forward에서 확정하고 backward는 단일 reshape만 수행한다.

🌐 broadcast_to: 확장 → 축소

  • 경로: lucid/_util/func.py
  • forward: 지정 shape으로 broadcast.
  • backward: broadcast된 축을 sum으로 축소하고, 원래 shape로 reshape.
def __grad__(self):
    input_shape = self.original_shape
    ...
    for axis, (in_dim, out_dim) in enumerate(zip(input_shape, self.shape)):
        if in_dim == 1 and out_dim > 1:
            self.result.grad = self.result.grad.sum(axis=axis, keepdims=True)
    return self.result.grad.reshape(self.original_shape)

원칙: forward에서 늘어난 축(크기 1 → n)은 backward에서 sum(axis)로 접어 넣는다. ndim이 달라진 경우 앞쪽에 (1,)*diff를 붙여 정렬한 후 검사한다.

🎯 where: 조건 분기와 zero-grad

  • 경로: lucid/_util/func.py
  • forward: np.where/mx.where로 조건 분기.
  • backward: cond에는 gradient를 흘리지 않고, a/b로만 분기해 전달.
def __grad__(self, lib_: ModuleType):
    cond = self.cond_.data
    grad = self.result.grad
    grad_cond = lib_.array(0.0)
    grad_a = lib_.where(cond, grad, 0)
    grad_b = lib_.where(lib_.logical_not(cond), grad, 0)
    return grad_cond, grad_a, grad_b

규칙: 조건 텐서에 대해서는 미분하지 않는다(항상 0). 분기된 영역은 마스크 연산으로 전달한다.

🔢 topk/sort: 인덱스 역정렬

  • 경로: sort, topk (lucid/_util/func.py)
  • forward: 값과 인덱스를 반환.
  • backward: 출력 gradient를 원래 인덱스 순서로 되돌린 뒤 입력 위치에 scatter/add.
def __grad__(self, lib_):
    grad = self.result[0].grad
    reverse_indices = lib_.argsort(self.result[1].data, axis=self.axis)
    grad_out = lib_.take_along_axis(grad, reverse_indices, axis=self.axis)
    return grad_out

topkindices에 따라 np.put_along_axis로 scatter한다. 핵심은 “정렬/선택”의 역연산을 gradient 경로에 맞춰 구현하는 것.

🧾 집계 연산과 keepdims

mean, sum, var 등 집계 연산은 reduction 축과 keepdims에 따라 gradient shape을 맞추는 것이 중요하다. (코드는 lucid/_tensor/tensor.py_util의 reduce 경로에 분포.)

일반 규칙:

  • forward: 선택한 axis를 제거하거나 keepdims=True일 경우 1로 유지.
  • backward:
    • reduction된 축이 사라졌다면 reshape(..., 1, ...)broadcast_to로 입력 shape로 확장.
    • mean은 추가로 1 / reduce_size 스케일을 곱한다.
    • var(xμ)(x-\mu)에 대해 2/N2/N 스케일을 곱하고, 필요시 keepdims 처리 후 broadcast.

이 로직은 _match_grad_shape와 동일한 철학을 따른다: 축을 없앴다면 backward에서 축을 다시 만들어 broadcast, 크기를 키웠다면 sum으로 접는다.


🧵 마무리

lucid._util의 유틸리티 연산은 모델 코드에서 자주 등장하지만, 실제로는 shape/axis bookkeepinggradient 역전을 정확히 처리하는 작은 규약 모음이다. 모든 함수가 NumPy 호출을 이 레이어에 가두고, backward에서 축/shape를 왕복시켜주는 패턴을 공유한다. 다음 문서에서는 이 위에 쌓인 nn.functional 계층(활성화, 정규화, 손실 등)을 정리하고, 컨볼루션은 별도로 상세히 다룰 예정이다.

profile
Korea Univ. Computer Science & Engineering

0개의 댓글