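nn.Flatten: flattens a contiguous range of dimensions of a tensor into a single dimension. With the defaults (start_dim=1, end_dim=-1), the first (batch) dimension is kept and all remaining dimensions are merged.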
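The shape arithmetic can be sketched in plain Python. `flattened_shape` below is a hypothetical helper (not part of PyTorch). It multiplies the sizes from `start_dim` through `end_dim`, both inclusive, and counts negative indices from the end. Its results should agree with the shapes `nn.Flatten` produces in the examples that follow.

```python
import math

import torch
import torch.nn as nn


def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: shape after merging dims start_dim..end_dim (inclusive).

    Assumes start_dim and end_dim are valid indices for `shape`.
    """
    ndim = len(shape)
    # Resolve negative indices (e.g. -1 -> ndim - 1), as Python indexing does
    start = start_dim % ndim
    end = end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


x = torch.randn(32, 1, 5, 5)
print(flattened_shape(x.shape))        # (32, 25)
print(flattened_shape(x.shape, 0, 2))  # (160, 5)
print(tuple(nn.Flatten()(x).shape))    # (32, 25)
```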
< Example1 >
import torch
import torch.nn as nn

input = torch.randn(32, 1, 5, 5)  # torch.Size([32, 1, 5, 5])
m = nn.Flatten()  # start_dim=1, end_dim=-1: merge dims 1 through 3 (-1) into one
output = m(input)
print(output.shape)  # torch.Size([32, 25]), i.e. 1 x 5 x 5 = 25
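Because the default `start_dim=1` leaves the batch dimension alone, `nn.Flatten()` should produce the same tensor as the functional `torch.flatten(x, start_dim=1)` or a `reshape` that keeps dim 0. A quick check on the same shape as Example1 (the variable names are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(32, 1, 5, 5)

out_module = nn.Flatten()(x)                  # module form, start_dim=1 by default
out_functional = torch.flatten(x, start_dim=1)  # functional form
out_reshape = x.reshape(x.size(0), -1)        # keep batch dim, merge the rest

print(out_module.shape)                         # torch.Size([32, 25])
print(torch.equal(out_module, out_functional))  # True
print(torch.equal(out_module, out_reshape))     # True
```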
< Example2 >
# With non-default parameters
input = torch.randn(32, 1, 5, 5)
m = nn.Flatten(0, 2)  # merge start_dim=0 through end_dim=2
output = m(input)
print(output.shape)  # torch.Size([160, 5]), i.e. 32 x 1 x 5 = 160
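For Example2, the same range can also be written with a negative `end_dim` (`-2` is dim 2 for a 4-D tensor), and the result matches `reshape(-1, 5)`. The last check illustrates that the merged dimensions are laid out in row-major order, so for this shape row `7` of the output is `x[1, 0, 2]`. A short sketch:

```python
import torch
import torch.nn as nn

x = torch.randn(32, 1, 5, 5)

out = nn.Flatten(0, 2)(x)       # merge dims 0, 1, 2 -> 32 * 1 * 5 = 160
out_neg = nn.Flatten(0, -2)(x)  # same range written with a negative end_dim
out_reshape = x.reshape(-1, 5)  # only the last dim (size 5) is kept

print(out.shape)                      # torch.Size([160, 5])
print(torch.equal(out, out_neg))      # True
print(torch.equal(out, out_reshape))  # True

# Merged dims are laid out in row-major (C) order:
# row index = (n * 1 + c) * 5 + h, so row 7 is x[1, 0, 2]
print(torch.equal(out[7], x[1, 0, 2]))  # True
```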
< Image data >
import torch
import torch.nn as nn

input_image = torch.rand(64, 28, 28)  # torch.Size([64, 28, 28])
flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.shape)  # torch.Size([64, 784]), i.e. 28 x 28 = 784
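In practice `nn.Flatten` usually sits between image-shaped data and a fully connected layer. The sketch below assumes a hypothetical classifier with a single `nn.Linear(784, 10)` layer (the 10 classes are an illustrative choice, not from the source). It also shows that the functional `torch.flatten` defaults to `start_dim=0`, so unlike `nn.Flatten()` it merges the batch dimension as well.

```python
import torch
import torch.nn as nn

input_image = torch.rand(64, 28, 28)  # 64 grayscale 28x28 images (no channel dim)

# Hypothetical small classifier: Flatten turns each 28x28 image into a 784-vector
model = nn.Sequential(
    nn.Flatten(),            # (64, 28, 28) -> (64, 784)
    nn.Linear(28 * 28, 10),  # 10 output classes, an assumption for illustration
)
logits = model(input_image)
print(logits.shape)  # torch.Size([64, 10])

# Different defaults: torch.flatten starts at dim 0,
# so it also merges the batch dimension.
print(torch.flatten(input_image).shape)  # torch.Size([50176])
```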
References:
https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html
https://wikidocs.net/156984