einsum(): operands do not broadcast with remapped shapes [original->remapped]:

boingboing·2024년 6월 25일

에러 로그

RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [32, 8, 1, 16]->[32, 8, 16, 1, 1][1, 64, 64]->[1, 1, 64, 64, 1]

개념

  • einsum 연산 -> 행렬, 벡터의 내적, 외적, 전치, 행렬곱 등을 일관성있게 표현할 수 있음.

  • EinSum : Einstein Summation Convention

  • 특정 index 집합에 대한 합(sigma) 연산을 간결하게 표시 함.

  • 왼쪽은 operand들의 차원을 나열함. "," 를 기준으로 구분.
  • 오른쪽은 출력값의 차원 인덱스를 나타냄.
  • 출력값에 표현되지 않은 인덱스들은 operand들을 곱한 후 해당 인덱스를 기준으로 더해진다고 함.

발생 위치

elif modelname == "logo":
    model = lib.models.axialnet.logo(img_size = imgsize, imgchan = imgchant)
    
def logo(pretrained=False, **kwargs):
    model = medt_net(AxialBlock,AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs)
    return model
  • 여기서 AxialBlock 사용

  • AxialBlock에의 init 부분에서 AxialAttention 사용


class AxialBlock(nn.Module):
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, kernel_size=56):
        super(AxialBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.))
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv_down = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.hight_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size)
        self.width_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True)
        self.conv_up = conv1x1(width, planes * self.expansion)
        self.bn2 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
  • AxialAttention의 forward 부분에서
class AxialAttention(nn.Module):

    def forward(self, x):
        # Transformations
        qkv = self.bn_qkv(self.qkv_transform(x))
        q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)

        # Calculate position embedding
        all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)
        q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)
        
        qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
        kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
        
        qk = torch.einsum('bgci, bgcj->bgij', q, k)

einsum 연산 설명

bgci,cij->bgij

  • bgci : q 텐서의 차원 이름(batch, group, channels, input )

  • cij : q_embedding 텐서의 차원 이름(channels, kernel size, kernel size)

  • bgij : 출력 텐서의 차원이름 (batch, group, input, kernelsize)

    q의 마지막 두 차원인 c와 i를 q_embedding의 첫 번째와 두 번째 차원인 c와 i와 곱하고, 그 결과를 b,g,i,j 차원으로 반환함.

왜 발생했는가?

런타임 오류: einsum(): 피연산자가 리매핑된 shape [original->remapped]으로 브로드캐스트되지 않습니다:

        Query               [32, 8, 1, 16]->[32, 8, 16, 1, 1] 
        Query Embedding     [1, 64, 64]   ->[1, 1, 64, 64, 1]

remapping은 왜 하는거지..?

Similar Situation

  • 똑같은 AxialBlock을 사용한 AxialUnet을 돌려봄-> 잘 돌아감.
  • LoGO도 모든 케이스가 안되는게 아니라 잘 되다가 안 됨.
  • 똑같은 쿼리에서 쿼리 임베딩 벡터 똑같은 방식으로 추출해서 einsum하는건데요..^^

einsum 연산에서 query랑 query embedding의 einsum 실행 시 shape을 봄

-------Axial Attention----------------------
query shape : torch.Size([64, 8, 8, 64])
q_embedding shape :  torch.Size([8, 64, 64])

-> query의 마지막 두 차원(8, 64) 가 query embedding의 첫 번째와 두 번째 차원(8, 64)와 같음 !

-> LoGo에서도 axiaulunet처럼 두 개의 차원을 같게 만들어주면 됨.

참고자료

https://baekyeongmin.github.io/dev/einsum/

0개의 댓글