config = {'seed': 2021,
          'root': './', 
          'n_splits': 5,
          'epoch': 20,
          'trainer': {
              'gpus': 1,
              'accumulate_grad_batches': 1,
            #  'progress_bar_refresh_rate': 1,
              'fast_dev_run': False,
              'num_sanity_val_steps': 0,
              'resume_from_checkpoint': None,
          },
          'transform':{
              'name': 'get_default_transforms',
              'image_size': 224
          },
          'train_loader':{
              'batch_size': 64,
              'shuffle': True,
              'num_workers': 4,
              'pin_memory': False,
              'drop_last': True,
          },
          'val_loader': {
              'batch_size': 64,
              'shuffle': False,
              'num_workers': 4,
              'pin_memory': False,
              'drop_last': False
         },
          'test_loader': {
                'batch_size': 8,
                'shuffle': False,
                'num_workers': 4,
                'pin_memory': False,
                'drop_last': False,
                },
          'model':{
              'name': 'swin_tiny_patch4_window7_224',
              #'name': 'swin_large_patch4_window7_224',
              'output_dim': 1
          },
          'optimizer':{
              'name': 'optim.AdamW',
              'params':{
                  'lr': 1e-4  #  수정
              },
          },
          'scheduler':{
              'name': 'optim.lr_scheduler.CosineAnnealingWarmRestarts',
              'params':{
                  'T_0': 20,
                  'eta_min': 1e-5,   # 수정
              }
          },
          'loss': 'nn.BCEWithLogitsLoss',
}

config = Box(config)

1. Box의 역할은?

Box는 dict를 object처럼 사용할 수 있게 해주는 라이브러리입니다. dict에서 key로만 접근하는 것이 아니라, attribute로도 접근할 수 있게 해줍니다. 이렇게 attribute로 접근하면, 해당 attribute가 없을 때 에러가 발생하는 대신 None을 반환합니다. 또한, dict에 nested structure가 있을 때도 쉽게 접근하고 변경할 수 있습니다. Box는 dot notation을 사용하여 dict에 접근할 수 있도록 해주는데, 이를 통해 코드 가독성이 좋아지고, 코드 작성이 더 편리해집니다.

2. nested structure가 뭐야?

Nested structure는 내부적으로 다른 객체나 자료형을 가지고 있는 객체를 의미합니다. 즉, 중첩된 구조를 갖는 것을 말합니다.
예를 들어, 리스트 내부에 딕셔너리가 있는 경우, 이는 nested structure라고 할 수 있습니다. 또한, 클래스 안에 다른 클래스를 멤버로 가지고 있는 경우에도 nested structure로 볼 수 있습니다.
딥러닝에서는 모델의 레이어, 레이어 안의 파라미터 등이 nested structure를 이루는 경우가 많습니다. 이러한 nested structure를 잘 다루면 코드를 더 간결하고 효율적으로 작성할 수 있습니다.

3. gradient accumulation이 뭐야?

Gradient Accumulation은 매 batch마다 backpropagation과 optimizer update를 수행하는 것이 아니라, n개의 batch의 gradient를 누적하여 한번에 backpropagation과 optimizer update를 수행하는 방식입니다.
이 방식은 GPU 메모리가 제한적인 경우에 유용하며, 큰 batch size를 사용하여 훈련시키는 것과 유사한 효과를 얻을 수 있습니다. 예를 들어, batch size가 32인 경우 4개의 batch를 누적하면 batch size가 128과 유사한 효과를 볼 수 있습니다. 또한, gradient accumulation을 사용하면 훈련 과정에서 불안정한 gradient 문제를 완화할 수 있으며, 학습 속도와 일관성을 유지하면서 큰 모델을 더 효과적으로 학습시킬 수 있습니다.

4. num_sanity_val_steps 가 0이면 어떻게 되는거야?

num_sanity_val_steps가 0이면, 모델이 최초의 validation step에서 예외가 발생하더라도 학습이 중지되지 않고 계속 진행됩니다. 즉, 모델이 학습을 시작할 때 최초의 validation step에서 문제가 발생하는 경우, 해당 문제를 보고하고 학습을 중단하는 대신 계속해서 학습을 진행합니다. 이것은 주로 디버깅이나 실험 초기 단계에서 모델이 정상적으로 작동하는지 확인하기 위해 사용됩니다. 하지만 일반적으로, num_sanity_val_steps를 0으로 설정하면 디버그 모드에서 모델을 테스트하는 것이 좋습니다.

5. BCEWithLogitsLoss에 대해 설명해줘

BCEWithLogitsLoss는 이진 분류 문제에서 주로 사용되는 손실 함수 중 하나입니다. BCE는 Binary Cross Entropy의 약자이며, 로짓값을 사용하므로 BCEWithLogitsLoss라고도 부릅니다.
BCEWithLogitsLoss는 각 샘플에 대한 예측값과 정답 레이블 간의 로그 소프트맥스 함수의 비용을 계산합니다. 이때, 소프트맥스 함수를 사용하면 예측값이 [0, 1] 범위에 있으며, 이진 분류 문제에서 사용하기 적합합니다.
BCEWithLogitsLoss는 예측값이 Sigmoid 함수를 거치기 이전의 값인 로짓값을 입력으로 받습니다. 이진 분류 문제에서는 출력 뉴런이 하나인 출력층에 Sigmoid 함수를 적용하여 0 또는 1에 가까운 값으로 변환한 예측값을 얻습니다. 로짓값을 사용하면 Sigmoid 함수의 출력값이 0 또는 1에 가까워지는 부분에서 기울기가 0에 가까워져서 학습이 더 어렵게 되는 문제를 해결할 수 있습니다.
BCEWithLogitsLoss는 모델이 예측한 값과 정답 레이블을 비교하여 둘 간의 손실을 계산합니다. 이 손실은 모델의 가중치를 업데이트할 때 사용되며, 모델이 예측을 개선하도록 학습됩니다.

6. torch.autograd.set_detect_anomaly(True)

torch.autograd.set_detect_anomaly(True)는 PyTorch에서 계산 그래프에서 오류를 검출하는 기능을 활성화하는 함수입니다.
이 함수를 호출하면, 계산 그래프에서 연산 중에 NaN이나 Inf 값 등의 에러가 발생할 경우, 해당 위치의 정보와 스택 트레이스를 출력합니다. 이를 통해 어디서 에러가 발생했는지를 확인하고, 이를 수정할 수 있습니다. 즉, 딥러닝 모델의 학습 중에 발생할 수 있는 수치적 안정성 문제를 검출하기 위한 디버깅 기능으로 사용됩니다.

profile
ML/DL swimmer

0개의 댓글