argparse에 대해

김태우·2025년 4월 5일

인공지능 개념

목록 보기
2/3

argparse란?

CLI에서 스크립트를 실행할때, 인자(argument)의 기본값(default)을 재정의할때 쓰이는 메서드이다.

인공지능 코드로는, 다음과 같은 예시로 활용해볼 수 있다.

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def main(args):
    # 1) 간단한 예시용 Dataset / Dataloader 구성
    # 실제로는 args.dataset_path 등을 사용해서 로딩
    
    if args.verbose:
        print(f"[INFO] Loading dataset from '{args.dataset_path}' ... (예시 코드, 실제 로직 아님)")
    
    # 예시: 100개의 10차원 가짜 데이터 + 2차원 라벨
    X = torch.randn(100, 10)
    y = torch.randint(0, 2, (100,))
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    
    # 2) 간단한 모델 정의 (예시)
    
    model = nn.Sequential(
        nn.Linear(10, 16),
        nn.ReLU(),
        nn.Linear(16, 2)
    )
    # device 설정 (cpu 또는 cuda)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    
    # 3) 손실 함수와 옵티마이저
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    
    
    # 4) 학습 루프
    
    for epoch in range(args.epochs):
        total_loss = 0.0
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            
            # Forward
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            
            # Backward & Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        if args.verbose:
            print(f"Epoch [{epoch+1}/{args.epochs}] - Loss: {total_loss/len(dataloader):.4f}")
    
    
    # 5) 모델 저장
    
    if args.save_model:
        torch.save(model.state_dict(), args.save_model)
        if args.verbose:
            print(f"[INFO] Model saved to '{args.save_model}'")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple Training Script Example")
    
    parser.add_argument("--epochs", type=int, default=10,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size for training")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="Learning rate")
    parser.add_argument("--dataset_path", type=str, default="data/train",
                        help="Path to training dataset (placeholder for example)")
    parser.add_argument("--device", type=str, default="cuda:0",
                        help="Device to use for training, e.g. 'cpu' or 'cuda:0'")
    parser.add_argument("--save_model", type=str, default="model.pth",
                        help="Path to save the trained model file (e.g., 'model.pth')")
    parser.add_argument("--`", action="store_true",
                        help="If set, print detailed logs during training")
    
    args = parser.parse_args()
    main(args)
profile
학부생의 성장 일기, 관심있는 분야:NLP

0개의 댓글