if epoch == 10:
    # At epoch 10, mark every parameter of the model as trainable
    # (requires_grad=True), presumably undoing an earlier freeze.
    # NOTE(review): this only affects training if the optimizer was built over
    # these parameters — verify how the optimizer is constructed.
    model.requires_grad_(True)
if epoch % args.save_freq == 0:
    # Every `args.save_freq` epochs: run the model over the validation set,
    # write one binarized prediction mask per validation sample into
    # `<direc>/<epoch>/`, then checkpoint the model weights.
    fulldir = direc + "/{}/".format(epoch)
    # Create the epoch directory once, before the loop, so the checkpoint save
    # below also succeeds when `valloader` yields no batches.
    os.makedirs(fulldir, exist_ok=True)

    # Inference mode: disables dropout and stops train-mode normalization
    # layers from using/updating batch statistics with validation data (which
    # would otherwise leak into the saved checkpoint). Previous mode is
    # restored after the loop.
    was_training = model.training
    model.eval()
    sample_idx = 0  # running count of validation samples, for fallback names
    with torch.no_grad():  # validation needs no autograd graph
        for X_batch, y_batch, *rest in valloader:
            y_out = model(X_batch.to(device='cuda'))
            # Binarize channel 1 at 0.5 into a {0, 255} uint8 mask per sample.
            # NOTE(review): assumes y_out is (batch, channels >= 2, H, W) and
            # that channel 1 is the foreground score — confirm against model.
            masks = ((y_out[:, 1] >= 0.5).to(torch.uint8) * 255).cpu().numpy()
            # Optional extra batch item (e.g. per-sample file names), if any.
            names = rest[0] if rest else None
            del X_batch, y_batch, y_out
            for i, mask in enumerate(masks):
                image_filename = names[i] if names is not None else None
                if not isinstance(image_filename, str):
                    # No usable name: number samples 001.png, 002.png, ...
                    image_filename = '%s.png' % str(sample_idx + 1).zfill(3)
                # cv2.imwrite signals failure via its return value, not an
                # exception; report it instead of silently losing the image.
                if not cv2.imwrite(fulldir + image_filename, mask):
                    print('warning: failed to write {}'.format(fulldir + image_filename))
                sample_idx += 1
    model.train(was_training)

    state = model.state_dict()
    torch.save(state, fulldir + args.modelname + ".pth")
    # NOTE(review): SOURCE indentation was lost; this save is assumed to belong
    # to the periodic block. Move it out one level if the latest weights should
    # be written every epoch.
    torch.save(state, direc + "/" + args.modelname + "_final_model.pth")
# The prediction images written every args.save_freq epochs (e.g. one set per 10 epochs) are generated from the validation set.