wandb hyper-parameter tuning & Sweep

Jihoon·2022년 10월 31일
1

AI_CONTEST

목록 보기
4/7
post-thumbnail
post-custom-banner
  • 궁금한점이나 해석이 잘못된 부분이 있으면 언제든지 댓글로 말씀해주시면 감사하겠습니다!

1.Wandb란?

  • 정의 : Weights & Biases의 줄인말로써 실험 결과를 시각적으로 알려주고, hyper-parameter tuning의 역할을 수행할 수 있도록 도와주는 ML Experiment tracking tool 입니다.

1) Wandb 사용법

1. wandb login

pip install wandb

wandb login #(cmd창) -> 사이트 통해 API key 복붙하기 (참고로 붙여넣기할 때 key가 cmd에서 안보이는데 바로 enter눌러주면 해결됌)

2. wandb init() 설정 -> wandb.init()을 통해서 시작을 해줍니다.

wandb.init()

3. wandb의 configuration 설정 -> wandb.config 또는 wandb.config.update()를 통해서 config를 설정해줍니다.

wandb.config.update({
            "batch_size" : self.batchsize,
            "epochs": self.epochs,
        })

4. wandb.watch(model) -> watch함수를 통해서 각 layer에 전파되는 gradients 값을 확인할 수 있습니다.

5. wandb.log -> log를 wandb에 남기는 것으로써 Vanishing 문제와 Exploding 문제가 발생하면 즉각 발견할 수 있고 인터넷만 있다면 실시간으로 학습 log 확인이 가능합니다!

wandb.log({
            'rmse_score' : rmse_score
            })
  • 전체 Code
def train(self):
        ### wandb init
        wandb.init()
        ### config 설정
        wandb.config.update({
            "batch_size" : self.batchsize,
            "epochs": self.epochs,
        })
        
        for epoch in range(self.epochs):
            self.model.train()
            total_loss = 0
            tk0 = tqdm.tqdm(self.train_dataloader, smoothing=0, mininterval=1.0)
            
            for i, (fields, target) in enumerate(tk0):
                self.model.zero_grad()
                fields, target = fields.to(self.device), target.to(self.device)

                y = self.model(fields)
                loss = self.criterion(y, target.float())

                loss.backward()
                self.optimizer.step()
                ### watch
                wandb.watch(self.model)
                
                total_loss += loss.item()
                if (i + 1) % self.log_interval == 0:
                    tk0.set_postfix(loss=total_loss / self.log_interval)
                    total_loss = 0
            # rmse 계산
            rmse_score = self.predict_train()
            print('epoch:', epoch, 'validation rmse:', rmse_score)
            
            ### wandb logging 설정 
            wandb.log({
            'rmse_score' : rmse_score
            })

2. hyper-parameter tuning (Sweep)

1) sweep(스윕)이란?

  • 기본적으로 hyper parameter를 자동으로 최적화 해주는 Tool!! (wandb GUI)
  • 장점
    1. 자동으로 최적의 값을 찾아준다!
    1. hyper-parameter와 metric(rmse, auc, loss등)간의 상관관계를 시각화해줌!

2) Sweep vs Optuna

  1. 현재까지 둘 다 활용해보면서 느낀점은 시각화적인 측면에서는 wandb가 더 낫지 않나?라는 생각이 든다!
  2. 그러나, 자동으로 찾아서 저장한 best_parameters라는 기능이 Optuna에는 있지만, wandb에는 없는 것이 조금 아쉽다.(오늘 공부한 바론 그런데 다른 방법이 있으시면 꼭 알려주세요 ~!)

3) 사용법 (GUI)

  • template 상에서 code로 구현하고 싶었으나, 많은 오류로 인해 해결하지 못해서 GUI방식으로 구현하는 방법을 모색했고, GUI 방식으로 하는 방법을 공유하겠습니다. 다시 도전해서 python template에서 작동하도록 구현한다면 그때 다시 공유해볼게요.
  1. Wandb에 설치된 Project Folder를 클릭한다. 그러면 아래에 나와있는 sweep이라는 아이콘을 찾을 수 있습니다.
  1. 클릭해서 들어가면 오른쪽 위 상단에 Create Sweep을 눌러줍니다.

  2. Sweep Configuration 화면이 나올텐데 거기서 program 변수에는 main또는 train을 하는 파이썬 파일을 기입합니다. 그리고, 아래에는 각종 hyperParameters가 있는데 이 부분은 wandb documentation에서 확인할 수 있습니다. (원하는 하이퍼파라미터와 범위를 설정하세요!)

    여기서 주의할 점은 자신의 Template의 arg element와 동일해야 한다 주의!!!

  3. Config를 완성하셨으면 다음 Launch Agent에서 가운데 표시되는 코드를 CMD창에 복붙하면 끝~!

  4. 코드를 돌리고 나면 아래의 창에서 자신이 실행한 모델의 결과를 확인하고, 시각적으로 분석할 수 있습니다 !

profile
장난감이 데이터인 사람
post-custom-banner

2개의 댓글

comment-user-thumbnail
2024년 5월 28일

헉 이거 보고 쉽게 했어요!! ㅎㅎ 감사합니다

1개의 답글