[PyTorch] nn.Flatten()

qw4735·2023년 3월 9일
0

PyTorch

목록 보기
5/8

nn.Flatten()

: 연속된 범위를 가진 차원을 하나의 텐서로 Flatten해준다.

< Example1 >

input = torch.randn(32,1,5,5)  # torch.Size([32, 1, 5, 5])
m = nn.Flatten()  # start_dim =  1, end_dim = -1  1 ~ 3(-1) 차원의 값을 하나로 줄이기 
output = m(input)
output.shape # torch.Size([32, 25])   e.g) 1x5x5 = 25
  • Default : start_dim = 1 , end_dim = -1
    0번째 차원인 배치 크기는 유지하되, 나머지 다차원 데이터를 1차원으로 줄여주는 기능을 함

< Example2 >

#With non-default parameters
input = torch.randn(32,1,5,5)
m = nn.Flatten(0, 2)  # start_dim =0 ~ end_dim = 2 
output = m(input)
output.shape  # torch.Size([160, 5])  e.g) 32 x 1 x 5 = 160
  • No_default parameters : start_dim = 0, end_dim = 2
    마지막 차원(3)은 유지하고, 나머지 다차원 데이터를 1차원으로 줄여주는 기능을 함

< Image data >

input_image = torch.rand(64, 28, 28)   # torch.Size([64,28,28])
flatten = nn.Flatten()
flat_image = flatten(input_image)  # torch.Size([64, 784])
  • 사이즈가 28x28인 2차원 텐서가 flatten과정을 거쳐서 1차원(784)으로 바뀜
  • 결론적으로, nn.Flatten()을 통해, 배치크기가 64이고 사이즈가 28x28인 텐서가 배치크기는 64로 그대로 유지되고, 나머지 다차원 데이터는 1차원으로 줄어드는 것을 알 수 있다.

reference : https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html
https://wikidocs.net/156984

0개의 댓글