에러 메시지
- UCTransNet: 입력 이미지 크기가 224×224가 아니면 동작하지 않는 문제
class UCTransNet(nn.Module):
    """U-Net style encoder/decoder whose four encoder skip features are passed
    through a ChannelTransformer (``self.mtc``) before decoding.

    Args:
        config: transformer config. ``config.base_channel`` and
            ``config.patch_sizes`` are read here; the whole object is also
            forwarded to ``ChannelTransformer``.
        n_channels: number of channels of the input image.
        n_classes: number of output channels. When 1, a sigmoid is applied to
            the output (pair with ``BCELoss``); otherwise raw scores are
            returned (pair with ``BCEWithLogitsLoss`` or a multi-class loss).
        img_size: height/width the ChannelTransformer is built for. Input
            images must have exactly this spatial size, so the default (224)
            has to be overridden when the data uses another size (e.g. 256).
        vis: if True, ``forward`` also returns the attention weights.
    """

    def __init__(self, config, n_channels=3, n_classes=1, img_size=224, vis=False):
        super().__init__()
        self.vis = vis
        self.n_channels = n_channels
        self.n_classes = n_classes
        # Kept so forward() can reject inputs of the wrong size with a clear
        # message instead of an opaque failure inside the transformer.
        self.img_size = img_size
        in_channels = config.base_channel

        # Encoder: channel width doubles at each level, capped at 8x base.
        self.inc = ConvBatchNorm(n_channels, in_channels)
        self.down1 = DownBlock(in_channels, in_channels * 2, nb_Conv=2)
        self.down2 = DownBlock(in_channels * 2, in_channels * 4, nb_Conv=2)
        self.down3 = DownBlock(in_channels * 4, in_channels * 8, nb_Conv=2)
        self.down4 = DownBlock(in_channels * 8, in_channels * 8, nb_Conv=2)

        # ChannelTransformer over the four skip features (x1..x4 in forward);
        # its sizing depends on img_size, which is why inputs must match it.
        self.mtc = ChannelTransformer(
            config, vis, img_size,
            channel_num=[in_channels, in_channels * 2, in_channels * 4, in_channels * 8],
            patchSize=config.patch_sizes,
        )

        # Decoder: each first argument is presumably the concatenated width of
        # the upsampled feature and the matching skip (e.g. 8C + 8C = 16C).
        self.up4 = UpBlock_attention(in_channels * 16, in_channels * 4, nb_Conv=2)
        self.up3 = UpBlock_attention(in_channels * 8, in_channels * 2, nb_Conv=2)
        self.up2 = UpBlock_attention(in_channels * 4, in_channels, nb_Conv=2)
        self.up1 = UpBlock_attention(in_channels * 2, in_channels, nb_Conv=2)
        self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1, 1), stride=(1, 1))
        self.last_activation = nn.Sigmoid()  # used only when n_classes == 1 (BCELoss)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: input batch, presumably (N, n_channels, H, W) — TODO confirm;
                H and W must both equal ``img_size``.

        Returns:
            The output map (sigmoid probabilities when ``n_classes == 1``,
            raw scores otherwise), or ``(output, att_weights)`` when ``vis``.

        Raises:
            RuntimeError: if the spatial size of ``x`` differs from ``img_size``.
        """
        x = x.float()

        # The ChannelTransformer is sized for img_size at construction time;
        # inputs of any other size otherwise fail later (presumably with a
        # shape-mismatch error from self.mtc). Fail early with an actionable
        # message. RuntimeError is kept, presumably the type the underlying
        # mismatch raised, so existing handlers still catch it.
        size = self.img_size
        expected_hw = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
        if tuple(x.shape[-2:]) != expected_hw:
            raise RuntimeError(
                "UCTransNet was built with img_size={} but got input of spatial "
                "size {}x{}; pass an img_size matching the data (and check the "
                "dataset's resize settings).".format(size, x.shape[-2], x.shape[-1])
            )

        # Encoder path.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Refine the skip features with the channel transformer.
        x1, x2, x3, x4, att_weights = self.mtc(x1, x2, x3, x4)

        # Decoder path, consuming the refined skips from deepest to shallowest.
        x = self.up4(x5, x4)
        x = self.up3(x, x3)
        x = self.up2(x, x2)
        x = self.up1(x, x1)

        if self.n_classes == 1:
            # Probabilities, for use with BCELoss.
            logits = self.last_activation(self.outc(x))
        else:
            # Raw scores, for BCEWithLogitsLoss or when n_classes > 1.
            logits = self.outc(x)

        if self.vis:  # also return the attention maps for visualization
            return logits, att_weights
        return logits
if model_type == 'UCTransNet':
    # Build the channel-transformer sub-config and log its main
    # hyper-parameters before constructing the network.
    config_vit = config.get_CTranS_config()
    logger.info('transformer head num: {}'.format(
        config_vit.transformer.num_heads))
    logger.info('transformer layers num: {}'.format(
        config_vit.transformer.num_layers))
    logger.info('transformer expand ratio: {}'.format(
        config_vit.expand_ratio))
    # NOTE: img_size is not passed here, so UCTransNet's default img_size is
    # used; it must match the size of the images the dataset produces.
    model = UCTransNet(
        config_vit,
        n_channels=config.n_channels,
        n_classes=config.n_labels,
    )
-> UCTransNet의 img_size 기본값을 256으로 바꾸고, Load_dataset 부분에 224로 고정(하드코딩)되어 있던 값도 함께 고치니 해결됨! (모델의 img_size와 데이터셋이 만드는 입력 이미지 크기가 반드시 일치해야 함)
# Excerpt showing the fix (method bodies omitted): only the default img_size
# of UCTransNet.__init__ changed. The class header appears twice below,
# presumably a copy-paste duplicate in these notes.
class UCTransNet(nn.Module):
# Before: img_size defaulted to 224, so inputs of another size (presumably
# 256x256 here) did not work.
#def __init__(self, config,n_channels=3, n_classes=1,img_size=224,vis=False):
class UCTransNet(nn.Module):
# After: default img_size changed to 256 to match the input images.
def __init__(self, config,n_channels=3, n_classes=1,img_size=256,vis=False):
으로 수정함 (img_size 기본값: 224 → 256).