Deep Learning 모델 중 Vision Tranformer 모델을 구현하는 중 이미지를 패치로 나누는 것이 감이 안잡혀 다른 사람의 코드를 참고하려고 하는데 다들 Einops라는 라이브러리를 사용한다.
이미지를 패치로 나누는 것 하나 때문에 라이브러리를 따로 설치하고 불러오는 과정을 굳이 해야하나 싶어서 PyTorch로 구현해보았다.
Patchify는 이미지를 일정 크기로 잘라내는 것을 말하며 위와 같이 작동한다.
우선 Tensor 형식으로 임시 이미지를 생성하기 위해 다음과 같이 코드를 작성합니다.
c= 2 # 채널의 수
p = 2 # 패치의 크기
h = 4 # 이미지의 높이
w = 4 # 이미지의 너비
img = torch.tensor(list(range(c*h*w)),dtype=torch.float32).reshape(c,h,w)+1 # 이미지를 c,h,w 형식으로 선언
이렇게 되면 예시 이미지와 동일한 이미지가 생성됩니다.
우선 패치로 나누기 위해 이미지를 패치의 크기에 맞게 reshape해줍니다.
img = img.reshape(c,h//p,p,w//p,p) # c,h//p,p,w//p,p
> tensor([[[[[ 1., 2.], <- 1번
[ 3., 4.]],
[[ 5., 6.], <- 2번
[ 7., 8.]]],
[[[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.]]]],
[[[[17., 18.],
[19., 20.]],
[[21., 22.],
[23., 24.]]],
[[[25., 26.],
[27., 28.]],
[[29., 30.],
[31., 32.]]]]])
우리의 목적은 1,2 다음에 3,4가 아닌 5,6이 오도록 해야합니다.
이를 위해 데이터가 읽히는 순서를 변경하기 위해서 Transpose 함수를 적용해줍니다.
1,2는 c,h//p,p,w//p,p 에서 4번째 w//p의 차원에 담겨있으며 1,2와 5,6을 구분하는 차원은 3번째 p이므로 transpose(2,3)을 적용해줍니다.
out = out.transpose(2,3) # c,h//p,w//p,p,p
> tensor([[[[[ 1., 2.],
[ 5., 6.]],<- 1번
[[ 3., 4.],
[ 7., 8.]]],
[[[ 9., 10.],
[13., 14.]],
[[11., 12.],
[15., 16.]]]],
[[[[17., 18.],
[21., 22.]],<- 2번
[[19., 20.],
[23., 24.]]],
[[[25., 26.],
[29., 30.]],
[[27., 28.],
[31., 32.]]]]])
위에서 보시는 것처럼 잘 묶인 것을 볼 수 있습니다.
또한 여기서 1번 위치에 있는 1,2,5,6 다음에는 2번 위치에 있는 17,18,21,22가 들어와야 합니다.
1,2,5,6은 c,h//p,w//p,p,p 3번째 w//p에 담겨 있으며 1,2,5,6과 17,18,21,22을 구분하는 것은 1번째 c이므로 transpose(0,2)를 적용해줍니다.
out = out.transpose(0,2) # w//p,h//p,c,p,p
> tensor([[[[[ 1., 2.],
[ 5., 6.]],
[[17., 18.],
[21., 22.]]],<- 1번
[[[ 9., 10.],
[13., 14.]],
[[25., 26.],
[29., 30.]]]],
[[[[ 3., 4.],
[ 7., 8.]],
[[19., 20.],
[23., 24.]]],<- 2번
[[[11., 12.],
[15., 16.]],
[[27., 28.],
[31., 32.]]]]])
이번에도 잘 나누어지는 것을 볼 수 있습니다.
하지만 1번에 있는 1,2,5,6,17,18,21,22 다음에는 2번에 있는 3,4,7,8,19,20,23,24가 와야 합니다.
1,2,5,6,17,18,21,22 는 w//p,h//p,c,p,p에서 2번째 h//p에 담겨있고 1,2,5,6,17,18,21,22와 3,4,7,8,19,20,23,24 는 1번째 w//p로 구분되므로 Transpose(0,1)를 적용해줍니다.
out = out.transpose(0,1) # h//p,w//p,c,p,p
> tensor([[[[[ 1., 2.],
[ 5., 6.]],
[[17., 18.],
[21., 22.]]],
[[[ 3., 4.],
[ 7., 8.]],
[[19., 20.],
[23., 24.]]]],
[[[[ 9., 10.],
[13., 14.]],
[[25., 26.],
[29., 30.]]],
[[[11., 12.],
[15., 16.]],
[[27., 28.],
[31., 32.]]]]])
원하는대로 잘 나누어졌습니다.
out = img.reshape(c,h//p,p,w//p,p) # c,h//p,p,w//p,p
out = out.transpose(2,3) # h//p,w//p,p,p
out = out.transpose(0,2) # w//p,h//p,c,p,p
out = out.transpose(0,1) # h//p,w//p,c,p,p
위 코드를 Permute를 이용해 요약하면 다음과 같습니다.
out = img.reshape(c,h//p,p,w//p,p) # c,h//p,p,w//p,p
out = out.permute(1,3,0,2,4) # h//p,w//p,c,p,p
원본
tensor([[[ 1., 2., 3., 4., 5., 6.],
[ 7., 8., 9., 10., 11., 12.],
[13., 14., 15., 16., 17., 18.],
[19., 20., 21., 22., 23., 24.],
[25., 26., 27., 28., 29., 30.],
[31., 32., 33., 34., 35., 36.]],
[[37., 38., 39., 40., 41., 42.],
[43., 44., 45., 46., 47., 48.],
[49., 50., 51., 52., 53., 54.],
[55., 56., 57., 58., 59., 60.],
[61., 62., 63., 64., 65., 66.],
[67., 68., 69., 70., 71., 72.]]])
결과
tensor([[[[[ 1., 2., 3.],
[ 7., 8., 9.],
[13., 14., 15.]],
[[37., 38., 39.],
[43., 44., 45.],
[49., 50., 51.]]],
[[[ 4., 5., 6.],
[10., 11., 12.],
[16., 17., 18.]],
[[40., 41., 42.],
[46., 47., 48.],
[52., 53., 54.]]]],
[[[[19., 20., 21.],
[25., 26., 27.],
[31., 32., 33.]],
[[55., 56., 57.],
[61., 62., 63.],
[67., 68., 69.]]],
[[[22., 23., 24.],
[28., 29., 30.],
[34., 35., 36.]],
[[58., 59., 60.],
[64., 65., 66.],
[70., 71., 72.]]]]])
원본
tensor([[[ 1., 2., 3., 4., 5., 6., 7., 8.],
[ 9., 10., 11., 12., 13., 14., 15., 16.],
[17., 18., 19., 20., 21., 22., 23., 24.],
[25., 26., 27., 28., 29., 30., 31., 32.]],
[[33., 34., 35., 36., 37., 38., 39., 40.],
[41., 42., 43., 44., 45., 46., 47., 48.],
[49., 50., 51., 52., 53., 54., 55., 56.],
[57., 58., 59., 60., 61., 62., 63., 64.]],
[[65., 66., 67., 68., 69., 70., 71., 72.],
[73., 74., 75., 76., 77., 78., 79., 80.],
[81., 82., 83., 84., 85., 86., 87., 88.],
[89., 90., 91., 92., 93., 94., 95., 96.]]])
결과
tensor([[[[[ 1., 2.],
[ 9., 10.]],
[[33., 34.],
[41., 42.]],
[[65., 66.],
[73., 74.]]],
[[[ 3., 4.],
[11., 12.]],
[[35., 36.],
[43., 44.]],
[[67., 68.],
[75., 76.]]],
[[[ 5., 6.],
[13., 14.]],
[[37., 38.],
[45., 46.]],
[[69., 70.],
[77., 78.]]],
[[[ 7., 8.],
[15., 16.]],
[[39., 40.],
[47., 48.]],
[[71., 72.],
[79., 80.]]]],
[[[[17., 18.],
[25., 26.]],
[[49., 50.],
[57., 58.]],
[[81., 82.],
[89., 90.]]],
[[[19., 20.],
[27., 28.]],
[[51., 52.],
[59., 60.]],
[[83., 84.],
[91., 92.]]],
[[[21., 22.],
[29., 30.]],
[[53., 54.],
[61., 62.]],
[[85., 86.],
[93., 94.]]],
[[[23., 24.],
[31., 32.]],
[[55., 56.],
[63., 64.]],
[[87., 88.],
[95., 96.]]]]])
def patchify(x,p):
c,h,w = x.shape
out = img.reshape(c,h//p,p,w//p,p) # c,h//p,p,w//p,p
out = out.permute(1,3,0,2,4) # h//p,w//p,c,p,p
return out
좋은 글 감사합니다. 자주 올게요 :)