sweep 사용방법

Leejaegun·2024년 11월 18일

Python & etc

목록 보기
4/27

1. set_wandb

def set_wandb(configs):
    wandb.login(key=configs.wandb.api_key)
    # Sweep 활성화 여부를 확인
    if configs.wandb.use_sweep:
        # wandb sweep에서 전달된 파라미터를 가져옴
        config_sweep = wandb.config
        
        # Sweep에서 전달된 값으로 configs 업데이트 (수정된 부분)
        configs.train.lr = config_sweep.get("train_lr", configs.train.lr)  
        configs.train.train_batch_size = config_sweep.get("train_batch_size", configs.train.train_batch_size)  
        configs.max_epoch = config_sweep.get("max_epoch", configs.max_epoch)  
        configs.model.name = config_sweep.get("model_name", configs.model.name)  
        configs.model.encoder_name = config_sweep.get("model_encoder_name", configs.model.encoder_name)  
        configs.loss.name = config_sweep.get("loss_name", configs.loss.name)  
        configs.scheduler.name = config_sweep.get("scheduler_name", configs.scheduler.name)  
    wandb.init(
        entity=configs.wandb.team_name, #팀  wandb page생기면.
        project=configs.wandb.project_name,
        name=configs.wandb.exp_name, #진행하는 실험의 이름? 뭔지 모르겠음.
        config={
                'model': configs.model.name,
                'resize': configs.image_size,
                'batch_size': configs.train.train_batch_size,
                'loss_name': configs.loss.name,
                'scheduler_name': configs.scheduler.name,
                'learning_rate': configs.train.lr,
                'epoch': configs.max_epoch
            }
    )

2. config.yaml에서 설정

# wandb
wandb:
  api_key: ##본인의 api 키 적으시면 됩니다.
  team_name: CV_SOTA
  project_name: "segmentation_project2"
  experiment_detail: 진행하는 실험의 이름
  exp_name: *model_name
  use_sweep: True 
  sweeep_path: "configs/train_configs/train/sweep.yaml" 

3. config_sweep.yaml설정

program: train.py
name: segmentation_sweep
method: bayes
command:
  - /home/jaegun/miniconda3/envs/AI_tech/bin/python  # Python 실행 경로(which python 을 통해 본인의 파이썬 실행경로적으셈)
  - ${program}  # train.py 실행할 프로그램 이름입니다. 위에서 program: train.py로 지정되어 있으므로, 실제로는 train.py가 여기에 대입됩
  - --config=configs/config.yaml  # config 파일 경로

metric:
  goal: maximize
  name: dice


parameters:
  train_lr:  # train 관련 하이퍼파라미터 (lr)
    distribution: uniform
    max: 0.1
    min: 0.0
  train_batch_size:  # train 관련 하이퍼파라미터 (batch size)
    values: [4, 8, 16]
  max_epoch:  # train 관련 하이퍼파라미터 (epoch)
    min: 5
    max: 15

  model_name:  # model 이름
    values: ["Unet", "DeepLabV3", "UnetPlusPlus"]
  model_encoder_name:  # model encoder
    values: ["efficientnet-b7", "efficientnet-b4", "resnet101", "resnet50"]

  loss_name:  # loss 관련 하이퍼파라미터
    values: ["BCEWithLogitsLoss", "CombinedLoss", "DiceLoss"]

  scheduler_name:  # scheduler 관련 하이퍼파라미터
    values: ["CosineAnnealingLR", "MultiStepLR", "ReduceLROnPlateau"]

4. sweep 시작 명령

wandb sweep configs/config_sweep.yaml

5. 에러가 생기면?

-> init으로 먼저 열어준다.
그 후에 하면 됩니다.

def set_wandb(configs):
    wandb.login(key=configs.wandb.api_key)
    
    wandb.init(
        entity=configs.wandb.team_name, #팀  wandb page생기면.
        project=configs.wandb.project_name,
        name=configs.wandb.exp_name, #진행하는 실험의 이름? 뭔지 모르겠음.
        config={
                'model': configs.model.name,
                'resize': configs.image_size,
                'batch_size': configs.train.train_batch_size,
                'loss_name': configs.loss.name,
                'scheduler_name': configs.scheduler.name,
                'learning_rate': configs.train.lr,
                'epoch': configs.max_epoch
            }
    )
    # Sweep 활성화 여부를 확인
    if configs.wandb.use_sweep:
        # wandb sweep에서 전달된 파라미터를 가져옴
        config_sweep = wandb.config
        
        # Sweep에서 전달된 값으로 configs 업데이트 (수정된 부분)
        configs.train.lr = config_sweep.get("train_lr", configs.train.lr)  
        configs.train.train_batch_size = config_sweep.get("train_batch_size", configs.train.train_batch_size)  
        configs.max_epoch = config_sweep.get("max_epoch", configs.max_epoch)  
        configs.model.name = config_sweep.get("model_name", configs.model.name)  
        configs.model.parameters.encoder_name = config_sweep.get("model_encoder_name", configs.model.parameters.encoder_name)  
        configs.loss.name = config_sweep.get("loss_name", configs.loss.name)  
        configs.scheduler.name = config_sweep.get("scheduler_name", configs.scheduler.name)  

config_sweep.yaml에서는

program: train.py
name: segmentation_sweep
method: bayes
command:
  - /home/jaegun/miniconda3/envs/AI_tech/bin/python  # Python 실행 경로(which python 을 통해 본인의 파이썬 실행경로적으셈)
  - ${program}  # train.py 실행할 프로그램 이름입니다. 위에서 program: train.py로 지정되어 있으므로, 실제로는 train.py가 여기에 대입됩
  - --config=configs/config.yaml  # config 파일 경로

metric:
  goal: maximize
  name: dice


parameters:
  train_lr:  # train 관련 하이퍼파라미터 (lr)
    distribution: uniform
    max: 0.1
    min: 0.0
  train_batch_size:  # train 관련 하이퍼파라미터 (batch size)
    values: [4, 8, 16]
  max_epoch:  # train 관련 하이퍼파라미터 (epoch)
    min: 5
    max: 15

  model_name:  # model 이름
    values: ["Unet", "DeepLabV3", "UnetPlusPlus"]
  model_encoder_name:  # model encoder
    values: ["efficientnet-b7", "efficientnet-b4", "resnet101", "resnet50"]

  loss_name:  # loss 관련 하이퍼파라미터
    values: ["BCEWithLogitsLoss", "CombinedLoss", "DiceLoss"]

  scheduler_name:  # scheduler 관련 하이퍼파라미터
    values: ["CosineAnnealingLR", "MultiStepLR", "ReduceLROnPlateau"]

이런식으로 model_name 밑에 바로 values 적어줘야 함

그리고 sweep 활성화 부분을 자세히 살펴보면

if configs.wandb.use_sweep:
        # wandb sweep에서 전달된 파라미터를 가져옴
        config_sweep = wandb.config
        
        # Sweep에서 전달된 값으로 configs 업데이트 (수정된 부분)
        configs.train.lr = config_sweep.get("train_lr", configs.train.lr)  
        configs.train.train_batch_size = config_sweep.get("train_batch_size", configs.train.train_batch_size)  
        configs.max_epoch = config_sweep.get("max_epoch", configs.max_epoch)  
        configs.model.name = config_sweep.get("model_name", configs.model.name)  
        configs.model.parameters.encoder_name = config_sweep.get("model_encoder_name", configs.model.parameters.encoder_name)  
        configs.loss.name = config_sweep.get("loss_name", configs.loss.name)  
        configs.scheduler.name = config_sweep.get("scheduler_name", configs.scheduler.name)

configs가 내가 원래 있는 config.yaml이고 이를 config_sweep.get에 loss_name, model_name 등등 get해서 values에 넣는것이다.

profile
Lee_AA

0개의 댓글