Source: https://www.slideshare.net/yongho/ss-79607172
I didn't fully understand this yet, so for now I'll just note that these techniques exist and move on.
import torch

# Xavier uniform initialization: draws weights from a uniform distribution
# scaled by the layer's fan-in/fan-out, so activation variance stays stable
# across layers instead of shrinking or blowing up.
linear = torch.nn.Linear(784, 256, bias=True)
torch.nn.init.xavier_uniform_(linear.weight)
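As a minimal sketch of how this scales to a whole network (the init_weights helper and the two-layer model here are my own illustration, not from the slides), Module.apply can run the initializer over every Linear layer at once:

# Hypothetical helper: Xavier-initialize every Linear layer in a model.
def init_weights(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        torch.nn.init.zeros_(m.bias)

model = torch.nn.Sequential(torch.nn.Linear(784, 256), torch.nn.ReLU(),
                            torch.nn.Linear(256, 10))
model.apply(init_weights)  # apply() calls init_weights on every submodule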
drop_prob = 0.3  # assumed value for illustration; any probability in (0, 1) works

linear1 = torch.nn.Linear(784, 512, bias=True)
linear2 = torch.nn.Linear(512, 512, bias=True)
linear3 = torch.nn.Linear(512, 512, bias=True)
linear4 = torch.nn.Linear(512, 10, bias=True)
relu = torch.nn.ReLU()
dropout = torch.nn.Dropout(p=drop_prob)  # p is the probability of dropping a unit

model = torch.nn.Sequential(linear1, relu, dropout,
                            linear2, relu, dropout,
                            linear3, relu, dropout,
                            linear4)

model.train()  # enable dropout (training mode)
model.eval()   # disable dropout (evaluation mode)
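A minimal sketch of where those two calls actually go in practice (the loop structure and names like train_loader, x_test, y_test are assumptions, not from the note): switch to train() before the training loop, and to eval() plus torch.no_grad() for evaluation.

# Hypothetical training/evaluation skeleton showing the mode switches.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()                         # dropout active: units dropped at random
for x, y in train_loader:             # train_loader is an assumed DataLoader
    optimizer.zero_grad()
    loss = criterion(model(x.view(-1, 784)), y)
    loss.backward()
    optimizer.step()

model.eval()                          # dropout off: the full network is used
with torch.no_grad():                 # no gradients needed for evaluation
    pred = model(x_test.view(-1, 784)).argmax(1)
    accuracy = (pred == y_test).float().mean()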
linear1 = torch.nn.Linear(784, 32, bias=True)
linear2 = torch.nn.Linear(32, 32, bias=True)
linear3 = torch.nn.Linear(32, 10, bias=True)
relu = torch.nn.ReLU()
bn1 = torch.nn.BatchNorm1d(32)  # normalizes each of the 32 features per mini-batch
bn2 = torch.nn.BatchNorm1d(32)

bn_model = torch.nn.Sequential(linear1, bn1, relu,
                               linear2, bn2, relu,
                               linear3)

bn_model.train()  # use per-batch statistics (training mode)
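Like dropout, batch normalization behaves differently at test time: train() mode normalizes with the current mini-batch's mean and variance, while eval() mode uses the running statistics accumulated during training. A minimal sketch with dummy inputs (the torch.randn tensors are my own illustration):

bn_model.train()                             # normalize with per-batch mean/variance
out_train = bn_model(torch.randn(64, 784))   # dummy mini-batch of 64 samples

bn_model.eval()                              # normalize with the running statistics
out_eval = bn_model(torch.randn(1, 784))     # eval mode works even for one sample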