
정리:
이 구조는 LLM의 지속학습 실험에 필요한 모든 핵심 기능(특징 추출, 망각 방지, 리플레이, 동적 배치, 체크포인트, 다양한 코드 자동완성 포맷 등)을 체계적으로 지원하도록 설계된 것이 특징이다.
각 클래스의 책임이 명확히 분리되어 있어, 실험 목적에 따라 독립적으로 교체·확장하기 용이하다.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CapaBoost + Continual Learning 통합 시스템
최신 연구 결과 기반 고성능 지속 학습 시스템
ICLR 2024 CapaBoost + MER + EWC + NCM + SCR
Created: June 2025
Authors: [James Kim]
실행 방법:
python cc_train.py --config config.yaml
또는
./run_training.sh
"""
import os
import sys
import json
import random
import logging
import argparse
import numpy as np
import time
import signal
import psutil
import GPUtil
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union, Any
from collections import defaultdict, deque
import datetime
from pathlib import Path
import warnings
from tqdm import tqdm
import yaml
# datasets 관련 import는 logger 초기화 후에 수행
# Transformers 관련 import 추가
# Transformers 관련 import 추가 (logger 초기화 후에 사용하기 위해 아래로 이동)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, Dataset, Subset
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
# 데이터셋 관련 import
from datasets import load_dataset
# Transformers 관련 import 추가
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
try:
from tensorboardX import SummaryWriter
except ImportError:
# TensorBoard를 사용할 수 없을 때 더미 클래스 제공
class SummaryWriter:
def __init__(self, *args, **kwargs):
logger.warning("TensorBoard와 TensorboardX 둘 다 설치되지 않았습니다. 로깅이 비활성화됩니다.")
def add_scalar(self, *args, **kwargs):
pass
def add_scalars(self, *args, **kwargs):
pass
def add_figure(self, *args, **kwargs):
pass
def close(self, *args, **kwargs):
pass
# 중요도에 따라 경고 필터링
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, message="torch.cuda.*_rng_state")
warnings.filterwarnings("ignore", category=UserWarning, message="The parameter 'pretrained' is deprecated")
# 로깅 설정
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler('cc_train.log')
]
)
logger = logging.getLogger(__name__)
# Transformers 관련 import 추가
try:
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, PeftModel
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
logger.warning("transformers 또는 PEFT 라이브러리가 설치되지 않았습니다. 실제 모델 로드가 불가능합니다.")
# Datasets 관련 import 추가
try:
from datasets import load_dataset
DATASETS_AVAILABLE = True
except ImportError:
DATASETS_AVAILABLE = False
logger.warning("datasets 라이브러리가 설치되지 않았습니다. 데이터셋 로드가 불가능합니다.")
# 전역 시드 고정 함수
def set_seed(seed: int):
"""완전한 재현 가능성을 보장하기 위한 시드 설정"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # 멀티 GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['PYTHONHASHSEED'] = str(seed)
logger.info(f"모든 랜덤 시드가 {seed}로 고정되었습니다.")
# AWS Spot Instance 중단 감지 클래스
class SpotInterruptionHandler:
"""AWS Spot Instance 중단 감지 및 대응 핸들러"""
def __init__(self, checkpoint_dir: str = "checkpoints", checkpoint_freq: int = 100):
self.checkpoint_dir = checkpoint_dir
self.checkpoint_freq = checkpoint_freq
self.interrupted = False
self.last_checkpoint_time = time.time()
self.setup_interruption_detector()
logger.info("Spot Instance 중단 감지기가 초기화되었습니다.")
def setup_interruption_detector(self):
"""중단 신호 감지 설정"""
os.makedirs(self.checkpoint_dir, exist_ok=True)
# SIGTERM 시그널 처리기 등록 (AWS Spot 중단 신호)
signal.signal(signal.SIGTERM, self.interruption_handler)
# 수동 체크 메서드를 위한 추가 설정
self.metadata_url = "http://169.254.169.254/latest/meta-data/spot/instance-action"
def interruption_handler(self, signum, frame):
"""중단 시그널 처리"""
self.interrupted = True
logger.warning("Spot Instance 중단 신호가 감지되었습니다!")
# 여기서는 플래그만 설정하고, 실제 체크포인트 저장은 메인 학습 루프에서 처리
def check_interruption(self):
"""EC2 메타데이터를 통한 중단 예정 확인"""
try:
import requests
response = requests.get(self.metadata_url, timeout=0.1)
if response.status_code == 200:
self.interrupted = True
logger.warning(f"Spot 중단 임박! 메타데이터: {response.json()}")
return True
except Exception:
# 중단 예정이 아니거나 메타데이터 접근 실패 시 무시
pass
return self.interrupted
def should_checkpoint(self, step: int) -> bool:
"""체크포인트 저장 필요 여부 확인"""
# 중단 감지 시 즉시 체크포인트
if self.check_interruption():
return True
# 정기적인 체크포인트
if step % self.checkpoint_freq == 0:
return True
# 시간 기반 체크포인트 (10분마다)
current_time = time.time()
if current_time - self.last_checkpoint_time > 600: # 10분
self.last_checkpoint_time = current_time
return True
return False
def start_monitoring(self, model=None, optimizer=None, scheduler=None, scaler=None, trainer=None):
"""스팟 인스턴스 중단 모니터링 시작 - test_environment.py와의 호환성을 위한 덧방화 함수"""
# test_environment.py에서 스팟 인스턴스 모니터링을 처리하민로 필요한 파라미터만 저장
if model is not None:
self.model = model
if optimizer is not None:
self.optimizer = optimizer
if scheduler is not None:
self.scheduler = scheduler
if scaler is not None:
self.scaler = scaler
if trainer is not None:
self.trainer = trainer
logger.info("스팟 인스턴스 체크포인트 상태 초기화 완료 (test_environment.py에서 모니터링 수행)")
def stop_monitoring(self):
"""스팟 인스턴스 중단 모니터링 종료 - test_environment.py와의 호환성을 위한 덧방화 함수"""
logger.info("스팟 인스턴스 모니터링 종료 요청됨 (test_environment.py 처리)")
def save_checkpoint(self, model, optimizer, scheduler, scaler, trainer, step: int):
"""체크포인트 저장"""
checkpoint_path = os.path.join(self.checkpoint_dir, f"checkpoint-{step}.pt")
state_dict = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict() if scheduler else None,
'scaler': scaler.state_dict() if scaler else None,
'step': step,
'trainer_state': trainer.get_state_dict()
}
torch.save(state_dict, checkpoint_path)
# 최신 체크포인트 심볼릭 링크 업데이트
latest_path = os.path.join(self.checkpoint_dir, "checkpoint-latest.pt")
# 심볼릭 링크 확인 및 안전하게 제거
if os.path.lexists(latest_path): # lexists는 심볼릭 링크 자체가 있는지 확인
try:
os.unlink(latest_path) # 심볼릭 링크를 안전하게 제거
except Exception as e:
logger.warning(f"심볼릭 링크 제거 중 오류: {e}")
# 심볼릭 링크 생성 시 예외 처리
try:
os.symlink(checkpoint_path, latest_path)
except FileExistsError:
# 여전히 존재한다면 강제로 다시 시도
logger.warning("심볼릭 링크가 여전히 존재함. 강제 삭제 후 재생성 시도...")
try:
os.remove(latest_path) # 강제 삭제
os.symlink(checkpoint_path, latest_path) # 다시 생성
except Exception as e:
logger.error(f"심볼릭 링크 강제 재생성 실패: {e}")
logger.info(f"심볼릭 링크 생성은 실패했지만, 체크포인트 {checkpoint_path}는 정상 저장됨")
logger.info(f"체크포인트가 저장되었습니다: {checkpoint_path}")
self.last_checkpoint_time = time.time()
def load_checkpoint(self, checkpoint_path, model, optimizer, scheduler, scaler, trainer):
"""지정된 경로의 체크포인트 로드"""
if not os.path.exists(checkpoint_path):
logger.warning(f"체크포인트 파일이 존재하지 않습니다: {checkpoint_path}")
return 0
logger.info(f"체크포인트 로드 중: {checkpoint_path}")
try:
# 기본 체크포인트 형식(.pt 파일) 로드 시도
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# 모델 로드
if 'model' in checkpoint:
model.load_state_dict(checkpoint['model'])
# 옵티마이저 로드
if 'optimizer' in checkpoint and optimizer is not None:
optimizer.load_state_dict(checkpoint['optimizer'])
# 스케줄러 로드
if 'scheduler' in checkpoint and scheduler is not None and checkpoint['scheduler'] is not None:
scheduler.load_state_dict(checkpoint['scheduler'])
# 스케일러 로드
if 'scaler' in checkpoint and scaler is not None and checkpoint['scaler'] is not None:
scaler.load_state_dict(checkpoint['scaler'])
# 트레이너 상태 로드
if 'trainer_state' in checkpoint and trainer is not None:
trainer.load_state_dict(checkpoint['trainer_state'])
step = checkpoint.get('step', 0)
logger.info(f"체크포인트 로드 완료. 스텝 {step}부터 재개합니다.")
return step
except Exception as e:
# 디렉토리인 경우 (최종 모델 디렉토리 로드 시도)
if os.path.isdir(checkpoint_path):
logger.info(f"체크포인트 경로가 디렉토리입니다: {checkpoint_path}")
try:
# 허깅페이스 모델 디렉토리에서 로드 시도
# PEFT 모델이 이미 적용된 경우 적절한 방식으로 처리
if hasattr(model, 'pretrained_model') or hasattr(model, 'base_model'):
# 이미 PEFT 모델인 경우
from peft import PeftModel, PeftConfig
logger.info(f"PEFT 모델에 대한 어댑터 로드 시도: {checkpoint_path}")
# 기존 모델에서 기본 모델 추출
if hasattr(model, 'pretrained_model'):
base_model = model.pretrained_model
else:
base_model = model.base_model if hasattr(model, 'base_model') else model
# 1. 기존 어댑터가 있는 경우 완전히 제거 (PEFT 0.4.0 방식)
if hasattr(model, 'unload'):
try:
model.unload()
logger.info("기존 PEFT 어댑터 제거 완료")
except Exception as e:
logger.warning(f"PEFT 어댑터 제거 중 오류(무시함): {e}")
# 2. 어댑터 구성 파일 확인
adapter_config_path = os.path.join(checkpoint_path, "adapter_config.json")
if not os.path.exists(adapter_config_path):
logger.warning(f"adapter_config.json 파일이 없어 일반 모델로 로드합니다: {checkpoint_path}")
model = base_model
model.load_state_dict(torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"), map_location="cpu"), strict=False)
return model, optimizer, scheduler, scaler, trainer
# 3. 어댑터 구성 정보 로드
try:
# PEFT 호환성 문제 해결을 위한 수정
# _prepare_model_for_peft_adapter 함수를 사용하지 않고 직접 어댑터 적용
# 어댑터 구성 파일 로드
peft_config = PeftConfig.from_pretrained(checkpoint_path)
logger.info(f"어댑터 구성 로드 완료: {peft_config.peft_type}, target_modules={getattr(peft_config, 'target_modules', 'N/A')}")
# 4. 기본 모델을 PEFT 어댑터에 맞게 준비 (가장 중요한 단계)
logger.info("기본 모델을 PEFT 어댑터에 맞게 준비 중...")
# PEFT 0.4.0에서는 이 함수가 없으므로 PeftModel이 자동으로 처리하도록 함
# 5. 체크포인트에서 가중치만 로드 (두 형식 지원: .bin 및 .safetensors)
bin_weights_path = os.path.join(checkpoint_path, "adapter_model.bin")
safetensors_weights_path = os.path.join(checkpoint_path, "adapter_model.safetensors")
# 6. 새 PeftModel 생성 (PEFT 0.4.0 방식)
model = PeftModel(base_model, peft_config, adapter_name="default")
# 7. 가중치 파일 여부 확인 및 로드
weights_loaded = False
# .bin 파일 확인
if os.path.exists(bin_weights_path):
logger.info(f"어댓터 가중치 (.bin) 로드 중: {bin_weights_path}")
adapter_weights = torch.load(bin_weights_path, map_location="cpu")
model.load_state_dict(adapter_weights, strict=False)
logger.info("어댓터 가중치 (.bin) 로드 완료")
weights_loaded = True
# .safetensors 파일 확인
elif os.path.exists(safetensors_weights_path):
try:
from safetensors import safe_open
from safetensors.torch import load_file
logger.info(f"어댓터 가중치 (.safetensors) 로드 중: {safetensors_weights_path}")
# safetensors 파일에서 가중치 로드
adapter_weights = load_file(safetensors_weights_path)
model.load_state_dict(adapter_weights, strict=False)
logger.info("어댓터 가중치 (.safetensors) 로드 완료")
weights_loaded = True
except Exception as e:
logger.error(f"safetensors 파일 로드 오류: {str(e)}")
# 가중치 파일을 찾지 못한 경우
if not weights_loaded:
logger.warning(f"어댓터 가중치 파일을 찾을 수 없음: {bin_weights_path} 또는 {safetensors_weights_path}")
logger.warning("새 어댓터로 초기화하여 계속합니다.")
# 8. 어댓터 활성화 확인 및 설정
if hasattr(model, 'active_adapter') and not model.active_adapter:
model.set_adapter("default")
logger.info("어댓터 활성화됨: default")
except Exception as e:
logger.error(f"PEFT 어댑터 로드 중 오류 발생: {str(e)}")
# 오류 발생 시 기본 모델 사용
model = base_model
else:
# 일반 모델인 경우
model.from_pretrained(checkpoint_path)
logger.info(f"모델이 로드되었습니다: {checkpoint_path}")
return 0 # 디렉토리에서 로드하면 스텝을 0으로 초기화
except Exception as inner_e:
logger.error(f"모델 로드 중 오류 발생: {inner_e}")
else:
logger.error(f"체크포인트 로드 중 오류 발생: {e}")
return 0
def load_latest_checkpoint(self, model, optimizer, scheduler, scaler, trainer):
"""최신 체크포인트 로드"""
latest_path = os.path.join(self.checkpoint_dir, "checkpoint-latest.pt")
if not os.path.exists(latest_path):
logger.info("로드할 체크포인트가 없습니다.")
return 0 # 시작 스텝
logger.info(f"체크포인트를 로드합니다: {latest_path}")
checkpoint = torch.load(latest_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
if scheduler and checkpoint['scheduler']:
scheduler.load_state_dict(checkpoint['scheduler'])
if scaler and checkpoint['scaler']:
scaler.load_state_dict(checkpoint['scaler'])
if 'trainer_state' in checkpoint:
trainer.load_state_dict(checkpoint['trainer_state'])
step = checkpoint['step']
logger.info(f"체크포인트 로드 완료, 스텝 {step}부터 재개합니다.")
return step
# CapaBoost 구현 클래스
class CapaBoostMaskGenerator:
"""
CapaBoost 정적 랜덤 마스크 생성기
ICLR 2024 논문의 완전한 구현
참고 논문: "Increasing Model Capacity for Free: A Simple Strategy for Parameter Efficient Fine-tuning"
"""
def __init__(self, sparsity: float = 0.6, num_masks: int = 4, seed: int = 42):
"""
Args:
sparsity: 마스크의 스파시티 (0.0-1.0, 0.6이 최적)
num_masks: 생성할 마스크 수 (4가 최적)
seed: 결정론적 마스크 생성을 위한 시드
"""
self.sparsity = sparsity
self.num_masks = num_masks
self.seed = seed
self.masks_cache = {}
logger.info(f"CapaBoost 초기화: 스파시티={sparsity}, 마스크 수={num_masks}, 시드={seed}")
def generate_static_masks(self, shape: Tuple[int, int], layer_name: str) -> List[torch.Tensor]:
"""
정적이고 결정론적인 랜덤 마스크 생성
논문의 Equation (3) 구현
"""
cache_key = f"{layer_name}_{shape}_{self.sparsity}_{self.num_masks}"
if cache_key in self.masks_cache:
return self.masks_cache[cache_key]
# 시드를 이용한 결정론적 마스크 생성
masks = []
for i in range(self.num_masks):
torch.manual_seed(self.seed + hash(layer_name) + i)
# 베르누이 분포를 이용한 스파스 마스크 생성
mask = torch.bernoulli(torch.full(shape, 1 - self.sparsity)).bool()
masks.append(mask)
self.masks_cache[cache_key] = masks
return masks
def apply_capaboost(self, weight: torch.Tensor, layer_name: str) -> torch.Tensor:
"""
CapaBoost 메커니즘 적용
z = Σ(w ⊙ mi)x + b 구현
"""
masks = self.generate_static_masks(weight.shape, layer_name)
# 마스킹된 가중치들의 합
masked_weights = []
for mask in masks:
masked_weight = weight * mask.float().to(weight.device)
masked_weights.append(masked_weight)
# 효과적인 랭크 증가를 위한 가중치 합산
effective_weight = sum(masked_weights)
return effective_weight
# CapaBoost 적용을 위한 Linear 래퍼 클래스
class CapaBoostLinear(nn.Module):
"""CapaBoost를 적용한 Linear 레이어"""
def __init__(self, in_features: int, out_features: int, bias: bool = True,
mask_generator: CapaBoostMaskGenerator = None, layer_name: str = None):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.linear = nn.Linear(in_features, out_features, bias)
self.mask_generator = mask_generator
self.layer_name = layer_name or f"linear_{id(self)}"
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.mask_generator is not None:
effective_weight = self.mask_generator.apply_capaboost(self.linear.weight, self.layer_name)
return F.linear(x, effective_weight, self.linear.bias)
else:
return self.linear(x)
# Fisher Information Matrix 계산 클래스
class FisherInformationComputer:
"""
정확한 Fisher Information Matrix 계산
EWC와 LwF에서 사용
"""
def __init__(self, mode: str = "empirical"):
"""
Args:
mode: Fisher 계산 모드 ('exact', 'empirical', 'diagonal')
"""
self.mode = mode
logger.info(f"Fisher Information 계산기 초기화: 모드={mode}")
def compute_fisher_diagonal(self, model: nn.Module, dataloader,
criterion, num_samples: int = 1000) -> Dict[str, torch.Tensor]:
"""
대각 Fisher Information Matrix 계산
F_i,i = E[(∂log p(y|x,θ)/∂θ_i)²] 구현
"""
fisher = {}
for name, param in model.named_parameters():
if param.requires_grad:
fisher[name] = torch.zeros_like(param)
model.eval()
samples_processed = 0
with torch.no_grad():
for data, target in tqdm(dataloader, desc="Fisher 계산 중"):
if samples_processed >= num_samples:
break
data, target = data.cuda(), target.cuda()
batch_size = data.size(0)
samples_processed += batch_size
# 모드별 Fisher 계산
if self.mode == "exact":
# 정확한 Fisher Information 계산
output = model(data)
log_probs = F.log_softmax(output, dim=1)
for class_idx in range(output.size(1)):
model.zero_grad()
class_log_prob = log_probs[:, class_idx].mean()
class_log_prob.backward(retain_graph=(class_idx < output.size(1) - 1))
for name, param in model.named_parameters():
if param.grad is not None:
fisher[name] += param.grad.pow(2) * batch_size / num_samples
elif self.mode == "empirical":
# 경험적 Fisher Information (실제 레이블 사용)
for i in range(batch_size):
model.zero_grad()
output = model(data[i:i+1])
loss = criterion(output, target[i:i+1])
loss.backward()
for name, param in model.named_parameters():
if param.grad is not None:
fisher[name] += param.grad.pow(2) / num_samples
elif self.mode == "diagonal":
# 대각 근사 Fisher (메모리 효율적)
output = model(data)
log_probs = F.log_softmax(output, dim=1)
# 실제 타겟에 대한 로그 확률의 그래디언트
model.zero_grad()
target_log_probs = log_probs[torch.arange(batch_size), target]
target_log_probs.mean().backward()
for name, param in model.named_parameters():
if param.grad is not None:
fisher[name] += param.grad.pow(2) * batch_size / num_samples
model.train()
return fisher
# Meta-Experience Replay 구현
class MetaExperienceReplay:
"""
Meta-Experience Replay 완전 구현
Stanford NLP의 원본 논문 기반
"""
def __init__(self, buffer_size: int = 5120, beta: float = 0.1,
gamma: float = 0.1, s: int = 10):
"""
Args:
buffer_size: 메모리 버퍼 크기
beta: 배치 내 업데이트 강도
gamma: 배치 간 업데이트 강도
s: Reptile 스텝 수
"""
self.buffer_size = buffer_size
self.buffer = []
self.beta = beta
self.gamma = gamma
self.s = s
self.class_counter = defaultdict(int)
logger.info(f"MER 초기화: 버퍼 크기={buffer_size}, β={beta}, γ={gamma}, s={s}")
def reservoir_sampling(self, x: torch.Tensor, y: torch.Tensor):
"""저장소 샘플링을 이용한 메모리 버퍼 관리"""
for i in range(x.size(0)):
sample = (x[i].clone(), y[i].clone())
if len(self.buffer) < self.buffer_size:
self.buffer.append(sample)
self.class_counter[y[i].item()] += 1
else:
# 클래스 밸런스를 유지하기 위한 선택적 대체
if self.should_replace(y[i].item()):
# 같은 클래스의 샘플 대체
indices = [j for j, (_, label) in enumerate(self.buffer) if label == y[i]]
if indices:
idx = random.choice(indices)
self.buffer[idx] = sample
def should_replace(self, class_label: int) -> bool:
"""클래스 밸런스를 위한 대체 결정"""
# 적게 등장한 클래스는 대체 확률 높임
counts = list(self.class_counter.values())
if not counts:
return True
avg_count = sum(counts) / len(counts)
current_count = self.class_counter[class_label]
# 평균보다 적게 등장한 클래스는 대체할 확률 높임
if current_count < avg_count:
return random.random() < 0.8
else:
return random.random() < 0.2
def sample_batch(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""버퍼에서 배치 샘플링"""
if len(self.buffer) == 0:
return None, None
indices = np.random.choice(len(self.buffer),
min(batch_size, len(self.buffer)),
replace=False)
x_batch = torch.stack([self.buffer[i][0] for i in indices])
y_batch = torch.stack([self.buffer[i][1] for i in indices])
return x_batch, y_batch
def meta_update(self, model: nn.Module, optimizer, criterion,
current_batch: Tuple[torch.Tensor, torch.Tensor],
memory_batch: Tuple[torch.Tensor, torch.Tensor]) -> float:
"""
MER 메타 업데이트 수행
Reptile 알고리즘 기반
"""
x_curr, y_curr = current_batch
x_mem, y_mem = memory_batch
# 현재 파라미터 저장
old_params = {name: param.clone() for name, param in model.named_parameters()}
# Within-batch Reptile update
for _ in range(self.s):
# 현재 배치 업데이트
optimizer.zero_grad()
output_curr = model(x_curr)
loss_curr = criterion(output_curr, y_curr)
loss_curr.backward()
optimizer.step()
if x_mem is not None:
# 메모리 배치 업데이트
inner_params = {name: param.clone() for name, param in model.named_parameters()}
optimizer.zero_grad()
output_mem = model(x_mem)
loss_mem = criterion(output_mem, y_mem)
loss_mem.backward()
optimizer.step()
# Within-batch interpolation
for name, param in model.named_parameters():
param.data = inner_params[name] + self.beta * (param.data - inner_params[name])
# Across-batch Reptile update (메타 업데이트)
if x_mem is not None:
for name, param in model.named_parameters():
param.data = old_params[name] + self.gamma * (param.data - old_params[name])
return loss_curr.item()
def get_state_dict(self):
"""상태 저장을 위한 딕셔너리 반환"""
return {
'buffer': self.buffer,
'class_counter': dict(self.class_counter),
'buffer_size': self.buffer_size,
'beta': self.beta,
'gamma': self.gamma,
's': self.s
}
def load_state_dict(self, state_dict):
"""저장된 상태 로드"""
self.buffer = state_dict['buffer']
self.class_counter = defaultdict(int, state_dict['class_counter'])
self.buffer_size = state_dict['buffer_size']
self.beta = state_dict['beta']
self.gamma = state_dict['gamma']
self.s = state_dict['s']
# 다이나믹 배치 사이즈 관리 클래스
class DynamicBatchSizeManager:
"""
메모리 사용량에 따라 배치 크기를 동적으로 조정
OOM 오류 방지 및 최적 성능 보장
"""
def __init__(self, initial_batch_size: int,
target_gpu_util: float = 0.85,
max_batch_size: int = 128,
min_batch_size: int = 1):
self.batch_size = initial_batch_size
self.target_gpu_util = target_gpu_util
self.max_batch_size = max_batch_size
self.min_batch_size = min_batch_size
self.adjustment_counter = 0
self.history = deque(maxlen=10) # 최근 조정 내역
logger.info(f"동적 배치 사이즈 관리자 초기화: 시작={initial_batch_size}, 목표 GPU 사용률={target_gpu_util}")
def update(self, oom_occurred: bool = False) -> int:
"""
현재 GPU 메모리 상태에 따라 배치 크기 업데이트
Args:
oom_occurred: OOM 오류 발생 여부
Returns:
새로운 배치 크기
"""
if oom_occurred:
# OOM 발생 시 배치 크기 즉시 감소
new_batch_size = max(self.min_batch_size, int(self.batch_size * 0.5))
self.history.append(("OOM", self.batch_size, new_batch_size))
self.batch_size = new_batch_size
logger.warning(f"OOM 감지, 배치 크기 감소: {self.batch_size}")
return self.batch_size
# 5번마다 배치 크기 조정 (너무 자주 조정하지 않도록)
self.adjustment_counter += 1
if self.adjustment_counter % 5 != 0:
return self.batch_size
# GPU 메모리 사용량 확인
try:
gpus = GPUtil.getGPUs()
if not gpus:
return self.batch_size
gpu = gpus[0] # 첫 번째 GPU 사용 가정
mem_util = gpu.memoryUtil
# 메모리 사용량에 따른 배치 크기 조정
if mem_util > self.target_gpu_util + 0.05:
# 메모리 사용량이 너무 높음, 배치 크기 감소
new_batch_size = max(self.min_batch_size, int(self.batch_size * 0.9))
action = "감소"
elif mem_util < self.target_gpu_util - 0.1 and self.batch_size < self.max_batch_size:
# 메모리 사용량이 낮음, 배치 크기 증가
new_batch_size = min(self.max_batch_size, int(self.batch_size * 1.1))
action = "증가"
else:
# 적정 범위 내, 유지
return self.batch_size
if new_batch_size != self.batch_size:
self.history.append((action, self.batch_size, new_batch_size))
self.batch_size = new_batch_size
logger.info(f"배치 크기 {action}: {self.batch_size} (GPU 메모리 사용률: {mem_util:.2f})")
except Exception as e:
logger.warning(f"배치 크기 업데이트 중 오류 발생: {e}")
return self.batch_size
# 적응적 학습률 스케줄러
class AdaptiveLRScheduler:
"""
선형 감소 + 웜업 + 종료 시 급격한 감소를 포함한
적응적 학습률 스케줄링
"""
def __init__(self, optimizer, total_steps: int,
warmup_steps: int = 0,
min_lr_ratio: float = 0.1,
end_lr_ratio: float = 0.01):
"""
Args:
optimizer: 학습률을 조정할 옵티마이저
total_steps: 총 훈련 스텝 수
warmup_steps: 웜업 스텝 수
min_lr_ratio: 최소 학습률 비율
end_lr_ratio: 종료 시 학습률 비율
"""
self.optimizer = optimizer
self.total_steps = total_steps
self.warmup_steps = warmup_steps
self.min_lr_ratio = min_lr_ratio
self.end_lr_ratio = end_lr_ratio
self.initial_lr = [group['lr'] for group in optimizer.param_groups]
self.scheduler = LambdaLR(optimizer, self.lr_lambda)
logger.info(f"적응적 학습률 스케줄러 초기화: 웜업={warmup_steps}, 총 스텝={total_steps}")
def lr_lambda(self, current_step: int):
"""학습률 계산 함수"""
if current_step < self.warmup_steps:
# 웜업 구간: 0 -> 1
return float(current_step) / float(max(1, self.warmup_steps))
# 훈련 막바지 10%에서는 더 급격히 감소
if current_step > 0.9 * self.total_steps:
progress = (current_step - 0.9 * self.total_steps) / (0.1 * self.total_steps)
decay = max(self.end_lr_ratio, (1.0 - progress) * (1.0 - 0.9) + 0.9 * self.end_lr_ratio)
return decay
# 메인 구간: 선형 감소 (1 -> min_lr_ratio)
progress = (current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
decay = max(self.min_lr_ratio, 1.0 - progress)
return decay
def step(self):
self.scheduler.step()
def get_last_lr(self):
return self.scheduler.get_last_lr()
def state_dict(self):
return self.scheduler.state_dict()
def load_state_dict(self, state_dict):
self.scheduler.load_state_dict(state_dict)
# Feature Adapter 구현
class FeatureAdapter(nn.Module):
"""
언어 모델의 큰 출력 차원을 분류기가 기대하는 작은 차원으로 변환하는 어댑터
행렬곱 차원 불일치 문제를 해결하기 위한 중간 변환 레이어
"""
def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 2048):
"""
Args:
input_dim: 입력 차원 (언어 모델의 vocab_size 또는 출력 차원)
output_dim: 출력 차원 (분류기가 기대하는 특징 차원)
hidden_dim: 중간 은닉층 차원
"""
super().__init__()
# 입력 차원이 매우 큰 경우 (vocab_size 등) 단계적 차원 축소
if input_dim > 10000:
self.layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(hidden_dim // 2, output_dim)
)
else:
# 입력 차원이 작은 경우 단순한 변환
self.layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(hidden_dim, output_dim)
)
logger.info(f"FeatureAdapter 초기화: {input_dim} → {output_dim} (hidden: {hidden_dim})")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
특징 변환 수행
Args:
x: 입력 텐서 - 다양한 형태의 모델 출력을 처리
Returns:
변환된 특징 텐서 (batch_size, output_dim)
"""
# 입력 텐서 형태 정규화
if len(x.shape) == 3: # (batch_size, seq_len, hidden_dim)
# 시퀀스의 마지막 토큰 사용 (언어 모델의 일반적 패턴)
x = x[:, -1, :]
elif len(x.shape) > 3:
# 더 고차원인 경우 배치 차원을 제외한 모든 차원을 평탄화
batch_size = x.size(0)
x = x.view(batch_size, -1)
elif len(x.shape) == 1:
# 1차원인 경우 배치 차원 추가
x = x.unsqueeze(0)
# 차원 변환 적용
return self.layers(x)
# Nearest Mean Classifier 구현
class NearestMeanClassifier(nn.Module):
"""
Stability Gap 완화를 위한 Nearest Mean Classifier
Supervised Contrastive Replay와 결합 시 최고 성능
"""
def __init__(self, num_classes: int, feat_dim: int):
"""
Args:
num_classes: 클래스 수
feat_dim: 특징 벡터 차원
"""
super().__init__()
self.num_classes = num_classes
self.feat_dim = feat_dim
self.register_buffer('class_means', torch.zeros(num_classes, feat_dim))
self.register_buffer('class_counts', torch.zeros(num_classes))
logger.info(f"Nearest Mean Classifier 초기화: 클래스={num_classes}, 차원={feat_dim}")
def update_class_means(self, features: torch.Tensor, labels: torch.Tensor):
"""클래스별 평균 특징 벡터 업데이트"""
for c in range(self.num_classes):
idx = (labels == c)
if idx.sum() > 0:
class_feats = features[idx]
self.class_means[c] = (self.class_means[c] * self.class_counts[c] + class_feats.sum(0)) / (self.class_counts[c] + idx.sum())
self.class_counts[c] += idx.sum()
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""
가장 가까운 클래스 평균으로 분류
Args:
features: 모델에서 추출한 특징 벡터
Returns:
각 클래스의 로짓 값
"""
# L2 정규화
norm_features = F.normalize(features, p=2, dim=1)
norm_class_means = F.normalize(self.class_means, p=2, dim=1)
# 코사인 유사도 계산
logits = torch.matmul(norm_features, norm_class_means.t())
# 온도 스케일링 (선택적)
# logits = logits / 0.07 # temperature
return logits
def get_state_dict(self):
"""상태 저장"""
return {
'class_means': self.class_means.clone(),
'class_counts': self.class_counts.clone(),
'num_classes': self.num_classes,
'feat_dim': self.feat_dim
}
def load_state_dict(self, state_dict, strict=True):
"""상태 로드"""
self.num_classes = state_dict['num_classes']
self.feat_dim = state_dict['feat_dim']
self.class_means = state_dict['class_means'].clone()
self.class_counts = state_dict['class_counts'].clone()
# Supervised Contrastive Replay 손실
class SupervisedContrastiveLoss(nn.Module):
"""
NCM과 함께 사용하는 지도 대조 손실
같은 클래스의 샘플들은 가까게, 다른 클래스의 샘플들은 멀게 배치
"""
def __init__(self, temperature: float = 0.07):
"""
Args:
temperature: 대조 손실 온도 파라미터
"""
super().__init__()
self.temperature = temperature
logger.info(f"지도 대조 손실 초기화: 온도={temperature}")
def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
지도 대조 손실 계산
Args:
features: 정규화된 특징 벡터 (batch_size, feat_dim)
labels: 클래스 레이블 (batch_size)
Returns:
대조 손실 값
"""
# L2 정규화
features = F.normalize(features, p=2, dim=1)
# 유사도 행렬 계산
similarity_matrix = torch.matmul(features, features.t()) / self.temperature
# 마스크 생성: 같은 클래스 = 1, 다른 클래스 = 0
batch_size = features.size(0)
labels_matrix = labels.expand(batch_size, batch_size).eq(labels.expand(batch_size, batch_size).t())
# 자기 자신은 제외
labels_matrix.fill_diagonal_(False)
# 양의 쌍이 있는지 확인
pos_pairs = labels_matrix.sum(dim=1) > 0
if not pos_pairs.any():
return torch.tensor(0.0, device=features.device)
# 마스크가 있는 샘플만 선택
labels_matrix = labels_matrix[pos_pairs]
similarity_matrix = similarity_matrix[pos_pairs]
# 로그-소프트맥스와 마스크 적용
logits_max, _ = torch.max(similarity_matrix, dim=1, keepdim=True)
logits = similarity_matrix - logits_max.detach() # 수치 안정성
exp_logits = torch.exp(logits)
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
# 양의 쌍의 로그 확률 평균
mean_log_prob_pos = (labels_matrix * log_prob).sum(1) / labels_matrix.sum(1)
# 손실 계산
loss = -mean_log_prob_pos.mean()
return loss
# EWC (Elastic Weight Consolidation) 구현
class ElasticWeightConsolidation:
"""
Elastic Weight Consolidation
이전 태스크의 중요한 가중치를 보호하여 Catastrophic Forgetting 완화
"""
def __init__(self, model: nn.Module, fisher_computer: FisherInformationComputer,
lambda_ewc: float = 5000.0, normalize_fisher: bool = True):
"""
Args:
model: 모델
fisher_computer: Fisher Information 계산기
lambda_ewc: EWC 강도 계수
normalize_fisher: Fisher 정규화 여부
"""
self.model = model
self.fisher_computer = fisher_computer
self.lambda_ewc = lambda_ewc
self.normalize_fisher = normalize_fisher
self.fisher_dict = {} # task_id -> fisher
self.optpar_dict = {} # task_id -> optimal parameters
logger.info(f"EWC 초기화: λ={lambda_ewc}, 정규화={normalize_fisher}")
def register_task(self, task_id: int, dataloader, criterion):
"""
새 태스크 등록 및 Fisher Information 계산
Args:
task_id: 태스크 ID
dataloader: 태스크 데이터 로더
criterion: 손실 함수
"""
logger.info(f"태스크 {task_id} Fisher Information 계산 중...")
# Fisher Information 계산
fisher = self.fisher_computer.compute_fisher_diagonal(
self.model, dataloader, criterion
)
# Fisher 정규화 (선택적)
if self.normalize_fisher:
for name, f in fisher.items():
fisher[name] = f / f.sum()
# 현재 파라미터 저장
optpar = {}
for name, param in self.model.named_parameters():
if param.requires_grad:
optpar[name] = param.data.clone()
# 태스크 정보 저장
self.fisher_dict[task_id] = fisher
self.optpar_dict[task_id] = optpar
logger.info(f"태스크 {task_id} 등록 완료")
def compute_ewc_loss(self) -> torch.Tensor:
"""
EWC 정규화 손실 계산
이전 태스크의 중요 파라미터 변화에 페널티 부여
Returns:
EWC 손실
"""
loss = torch.tensor(0., device=next(self.model.parameters()).device)
# 등록된 태스크가 없으면 손실 없음
if not self.fisher_dict:
return loss
# 각 태스크에 대한 EWC 손실 계산
for task_id, fisher in self.fisher_dict.items():
optpar = self.optpar_dict[task_id]
for name, param in self.model.named_parameters():
if name in fisher and param.requires_grad:
try:
# 텐서 처리를 안전하게 수행
# 가능한 경우 먼저 머신 에러 체크
if fisher[name].shape != param.shape or optpar[name].shape != param.shape:
logger.warning(f"EWC 손실 계산 시 텐서 형태 불일치 무시: {name}, {fisher[name].shape}, {param.shape}, {optpar[name].shape}")
continue
# 두 텐서의 차이
diff = param - optpar[name]
# 제곱 계산
squared_diff = diff.pow(2)
# Fisher로 가중치 계산
weighted_squared_diff = fisher[name] * squared_diff
# 합계 계산
param_loss = weighted_squared_diff.sum()
# 확인 후 추가
if torch.isfinite(param_loss).all(): # NaN 확인
loss += param_loss
else:
logger.warning(f"EWC 손실에서 NaN/Inf 발견, 무시함: {name}")
except Exception as e:
logger.warning(f"EWC 손실 계산 중 오류 발생(무시): {name}, {str(e)}")
continue
# 태스크 수로 정규화
num_tasks = len(self.fisher_dict)
if num_tasks > 0:
loss = self.lambda_ewc * loss / num_tasks
return loss
def get_state_dict(self):
"""상태 저장"""
return {
'fisher_dict': self.fisher_dict,
'optpar_dict': self.optpar_dict,
'lambda_ewc': self.lambda_ewc,
'normalize_fisher': self.normalize_fisher
}
def load_state_dict(self, state_dict):
"""상태 로드"""
self.fisher_dict = state_dict['fisher_dict']
self.optpar_dict = state_dict['optpar_dict']
self.lambda_ewc = state_dict['lambda_ewc']
self.normalize_fisher = state_dict['normalize_fisher']
# 지속 학습 메트릭 계산기
class ContinualLearningMetrics:
"""
지속 학습 전용 평가 메트릭
ACC, BWT, FWT, Stability Gap 등 포괄적 평가
"""
def __init__(self):
self.task_accuracies = defaultdict(list)
self.stability_gaps = []
self.task_names = []
def update(self, task_id: int, test_results: Dict[int, float]):
"""
태스크별 정확도 업데이트
Args:
task_id: 현재 학습 중인 태스크 ID
test_results: 모든 이전 태스크에 대한 테스트 결과 {task_id: accuracy}
"""
for test_task_id, accuracy in test_results.items():
self.task_accuracies[test_task_id].append(accuracy)
def record_stability_gap(self, task_id: int, accuracies: List[float]):
"""
Stability Gap 기록
Args:
task_id: 태스크 ID
accuracies: 학습 중 이전 태스크의 정확도 변화
"""
if len(accuracies) < 3: # 너무 적은 데이터
return
initial_acc = accuracies[0]
min_acc = min(accuracies)
final_acc = accuracies[-1]
gap = initial_acc - min_acc
recovery = final_acc - min_acc
self.stability_gaps.append({
'task_id': task_id,
'gap': gap,
'recovery': recovery,
'recovery_ratio': recovery / gap if gap > 0 else 1.0
})
def compute_metrics(self) -> Dict[str, float]:
"""
BWT, FWT, ACC, Stability Gap 계산
Returns:
계산된 메트릭들
"""
num_tasks = len(self.task_accuracies)
# Average Accuracy (ACC)
final_accuracies = [self.task_accuracies[i][-1] for i in range(num_tasks)]
acc = np.mean(final_accuracies)
# Backward Transfer (BWT)
bwt = 0
for i in range(num_tasks - 1):
bwt += self.task_accuracies[i][-1] - self.task_accuracies[i][i]
bwt /= (num_tasks - 1) if num_tasks > 1 else 1
# Forward Transfer (FWT) - 첫 번째 에포크 성능 기반
fwt = 0
baseline_acc = 0 # 무작위 성능 가정
for i in range(1, num_tasks):
if len(self.task_accuracies[i]) > 0:
fwt += self.task_accuracies[i][0] - baseline_acc
fwt /= (num_tasks - 1) if num_tasks > 1 else 1
# Average Stability Gap
avg_gap = 0
avg_recovery_ratio = 0
if self.stability_gaps:
avg_gap = np.mean([gap['gap'] for gap in self.stability_gaps])
avg_recovery_ratio = np.mean([gap['recovery_ratio'] for gap in self.stability_gaps])
return {
'ACC': acc,
'BWT': bwt,
'FWT': fwt,
'final_accuracies': final_accuracies,
'stability_gap': avg_gap,
'recovery_ratio': avg_recovery_ratio
}
def get_state_dict(self):
"""상태 저장"""
return {
'task_accuracies': dict(self.task_accuracies),
'stability_gaps': self.stability_gaps,
'task_names': self.task_names
}
def load_state_dict(self, state_dict):
"""저장된 상태 로드"""
self.task_accuracies = defaultdict(list, state_dict['task_accuracies'])
self.stability_gaps = state_dict['stability_gaps']
self.task_names = state_dict['task_names']
# 메인 지속 학습 트레이너 클래스
class ContinualLearner:
"""
CapaBoost + Continual Learning 통합 시스템
높은 성공률과 안정성을 위한 모든 기법 통합
"""
def __init__(self, config: Dict[str, Any]):
"""
Args:
config: 설정 파라미터
"""
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 시드 설정
set_seed(config.get('seed', 42))
# 모델 구성
self.build_model()
# 핵심 컴포넌트 초기화
self.setup_components()
# 메트릭 및 로깅 설정
self.metrics = ContinualLearningMetrics()
self.writer = SummaryWriter(config.get('log_dir', 'runs/continual_learning'))
# AWS Spot Instance 중단 핸들러
self.spot_handler = SpotInterruptionHandler(
checkpoint_dir=config.get('checkpoint_dir', 'checkpoints'),
checkpoint_freq=config.get('checkpoint_freq', 100)
)
# 배치 크기 관리자
self.batch_manager = DynamicBatchSizeManager(
initial_batch_size=config.get('batch_size', 32),
target_gpu_util=config.get('target_gpu_util', 0.85),
max_batch_size=config.get('max_batch_size', 128)
)
# 학습 상태
self.current_task_id = 0
self.global_step = 0
self.best_accuracies = {}
logger.info("ContinualLearner 초기화 완료")
def build_model(self):
"""모델 구성 - DeepSeek-Coder 모델 로드"""
model_config = self.config.get('model', {})
model_name = model_config.get('name', 'deepseek-coder')
mode = self.config.get('mode', 'prompt') # 학습 모드 확인
# DeepSeek-Coder 모델 로딩 시도
if TRANSFORMERS_AVAILABLE and model_name == 'deepseek-coder':
try:
# AWS 환경에 맞는 기본 모델 경로 설정 (절대경로)
base_model = "/home/ubuntu/deepseek-coder/models/deepseek-coder-6.7b-instruct" # 기본 모델
# 홈 디렉토리 기반 모델 루트 경로 설정
model_base_path = os.path.expanduser("~/deepseek-coder/models")
# 모드별 모델 경로 확인
if mode == 'complete':
# 코드 자동완성 모드
model_path = os.path.join(model_base_path, "autocomplete-finetuned/final_model")
logger.info(f"[자동완성] 모드 - 모델 경로: {model_path}")
elif mode == 'prompt':
# 프롬프트 모드
model_path = os.path.join(model_base_path, "prompt-finetuned/final_model")
logger.info(f"[프롬프트] 모드 - 모델 경로: {model_path}")
elif mode == 'comment':
# 주석 모드
model_path = os.path.join(model_base_path, "comment-finetuned/final_model")
logger.info(f"[주석] 모드 - 모델 경로: {model_path}")
elif mode == 'error_fix':
# 오류 수정 모드
model_path = os.path.join(model_base_path, "error-fix-finetuned/final_model")
logger.info(f"[오류수정] 모드 - 모델 경로: {model_path}")
else:
# 알 수 없는 모드일 경우 기본으로 prompt 모드 사용
model_path = os.path.join(model_base_path, "prompt-finetuned/final_model")
logger.info(f"알 수 없는 모드: {mode}, 프롬프트 모델 사용: {model_path}")
# 경로가 없으면 기본 모델 사용 예정임을 로깅
if not os.path.exists(model_path):
logger.warning(f"지정된 모델 경로가 존재하지 않습니다: {model_path}")
logger.warning(f"기본 모델 {base_model}를 사용합니다.")
use_base_model = True
else:
use_base_model = False
# 양자화 설정
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True
)
try:
if use_base_model:
# 기본 모델 로드
logger.info(f"기본 모델 로드 시도: {base_model}")
self.tokenizer = AutoTokenizer.from_pretrained(base_model)
self.model = AutoModelForCausalLM.from_pretrained(
base_model,
quantization_config=bnb_config,
device_map="auto"
)
logger.info(f"기본 모델 로드 성공!")
else:
# 학습된 모델 로드
logger.info(f"학습된 모델 로드 시도: {model_path}")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
quantization_config=bnb_config,
device_map="auto"
)
logger.info(f"학습된 {mode} 모델 로드 성공!")
except Exception as e:
# 오류 발생시 기본 모델로 돌아가기
logger.warning(f"모델 로드 오류: {str(e)}")
logger.info(f"기본 모델 로드 시도(콜백): {base_model}")
try:
self.tokenizer = AutoTokenizer.from_pretrained(base_model)
self.model = AutoModelForCausalLM.from_pretrained(
base_model,
quantization_config=bnb_config,
device_map="auto"
)
logger.info(f"기본 모델 로드 성공(fallback)!")
except Exception as nested_e:
logger.error(f"기본 모델 로드도 실패: {str(nested_e)}")
# 모델 로드가 완전히 실패한 경우 - dummy 모델 사용
self._create_dummy_model()
logger.warning("모든 모델 로드 시도 실패. 더미 모델로 최소 한의 실행을 보장합니다.")
# LoRA 설정
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
bias="none",
task_type="CAUSAL_LM"
)
# LoRA 적용
self.model = get_peft_model(self.model, lora_config)
# 모델 설정 저장
self.feat_dim = 4096 # DeepSeek-Coder hidden size
self.vocab_size = self.model.config.vocab_size # 언어 모델 어휘 크기
# 언어 모델 로그 정보
logger.info(f"언어 모델 어휘 크기: {self.vocab_size}, 히든 차원: {self.feat_dim}")
# 분류기와 FeatureAdapter 제거 - 언어 모델 내장 손실만 사용하도록 수정
logger.info(f"DeepSeek-Coder 모델 로드 성공")
except Exception as e:
logger.error(f"DeepSeek-Coder 모델 로드 오류: {str(e)}")
# 오류 발생시 더미 모델로 돌아감
self._create_dummy_model()
else:
# Transformers 미설치 또는 지원되지 않는 모델일 경우 더미 모델 사용
logger.warning(f"Transformers 라이브러리 미설치 또는 지원되지 않는 모델: {model_name}")
self._create_dummy_model()
def _create_dummy_model(self):
"""더미 모델 생성 (테스트용)"""
logger.warning("더미 모델을 사용합니다. 실제 학습에는 적합하지 않습니다.")
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(784, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU()
).to(self.device)
self.feat_dim = 256
self.vocab_size = 1000 # 더미 모델의 어휘 사이즈
# 분류기와 Feature Adapter 제거
# 언어 모델만 사용하고 내장 손실 함수 사용
logger.info(f"더미 모델 생성 완료 - 히든 차원: {self.feat_dim}")
# CapaBoost 적용
if self.config.get('use_capaboost', True):
self.apply_capaboost()
# 그래디언트 체크포인팅 (메모리 최적화)
if self.config.get('use_gradient_checkpointing', True):
# Hugging Face 모델은 gradient_checkpointing_enable 메서드가 있지만 Sequential는 없음
# 안전하게 메서드 존재 여부 확인 후 호출
if hasattr(self.model, 'gradient_checkpointing_enable'):
self.model.gradient_checkpointing_enable()
logger.info("그래디언트 체크포인팅이 활성화되었습니다.")
else:
logger.warning("이 모델은 그래디언트 체크포인팅을 지원하지 않습니다.")
def apply_capaboost(self):
"""CapaBoost를 모델의 선형 레이어에 적용"""
capaboost_config = self.config.get('capaboost', {})
self.mask_generator = CapaBoostMaskGenerator(
sparsity=capaboost_config.get('sparsity', 0.6),
num_masks=capaboost_config.get('num_masks', 4),
seed=capaboost_config.get('seed', 42)
)
# 각 선형 레이어를 CapaBoost 레이어로 교체
for name, module in list(self.model.named_children()):
if isinstance(module, nn.Linear):
# CapaBoost Linear로 교체
capaboost_linear = CapaBoostLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
mask_generator=self.mask_generator,
layer_name=name
)
# 가중치 복사
capaboost_linear.linear.weight.data.copy_(module.weight.data)
if module.bias is not None:
capaboost_linear.linear.bias.data.copy_(module.bias.data)
# 모듈 교체
if isinstance(self.model, nn.Sequential):
self.model._modules[name] = capaboost_linear
else:
setattr(self.model, name, capaboost_linear)
logger.info("CapaBoost가 모든 선형 레이어에 적용되었습니다.")
def setup_components(self):
"""핵심 컴포넌트 초기화"""
# 옵티마이저 설정
# config.yaml에서 learning_rate 키를 직접 가져옴 (문자열을 float로 변환)
learning_rate = float(self.config.get('learning_rate', '2e-4'))
weight_decay = float(self.config.get('weight_decay', '0.01'))
logger.info(f"학습률 초기화: {learning_rate}, 가중치 감소율: {weight_decay}")
self.optimizer = optim.AdamW(
self.model.parameters(),
lr=learning_rate,
weight_decay=weight_decay
)
# Fisher Information 계산기
fisher_config = self.config.get('fisher', {})
self.fisher_computer = FisherInformationComputer(
mode=fisher_config.get('mode', 'empirical')
)
# EWC 설정
ewc_config = self.config.get('ewc', {})
self.ewc = ElasticWeightConsolidation(
model=self.model,
fisher_computer=self.fisher_computer,
lambda_ewc=ewc_config.get('lambda', 5000.0),
normalize_fisher=ewc_config.get('normalize', True)
)
# Meta-Experience Replay
mer_config = self.config.get('mer', {})
self.mer = MetaExperienceReplay(
buffer_size=mer_config.get('buffer_size', 5120),
beta=mer_config.get('beta', 0.1),
gamma=mer_config.get('gamma', 0.1),
s=mer_config.get('s', 10)
)
# 분류기 관련 손실 함수 제거 - 언어 모델 내장 손실만 사용
logger.info("언어 모델링 내장 손실 함수 사용 설정")
# 혼합 정밀도 훈련
self.use_amp = self.config.get('use_amp', True) and torch.cuda.is_available()
self.scaler = GradScaler() if self.use_amp else None
if self.use_amp:
logger.info("혼합 정밀도 훈련(AMP)이 활성화되었습니다.")
# 학습률 스케줄러
scheduler_config = self.config.get('scheduler', {})
self.scheduler = AdaptiveLRScheduler(
optimizer=self.optimizer,
total_steps=scheduler_config.get('total_steps', 10000),
warmup_steps=scheduler_config.get('warmup_steps', 100),
min_lr_ratio=scheduler_config.get('min_lr_ratio', 0.1),
end_lr_ratio=scheduler_config.get('end_lr_ratio', 0.01)
)
def train_task(self, task_id: int, train_loader: DataLoader, val_loader: DataLoader, test_loaders: Dict[int, DataLoader]):
"""
단일 태스크에 대한 훈련 수행
Args:
task_id: 태스크 ID
train_loader: 훈련 데이터 로더
val_loader: 검증 데이터 로더
test_loaders: 이전 태스크 테스트 데이터 로더 {task_id: loader}
"""
self.current_task_id = task_id
# 태스크 구성
max_epochs = self.config.get('max_epochs', 10)
patience = self.config.get('patience', 3)
# 조기 종료 추적
best_val_acc = 0.0
patience_counter = 0
# 태스크 훈련 시작
logger.info(f"=== 태스크 {task_id} 훈련 시작 ===")
for epoch in range(max_epochs):
# 훈련 에포크
train_loss, train_acc = self.train_epoch(
train_loader=train_loader,
epoch=epoch,
task_id=task_id
)
# 검증
val_loss, val_acc = self.evaluate(val_loader)
# 모든 이전 태스크에 대한 테스트
test_results = {}
for test_task_id, test_loader in test_loaders.items():
_, test_acc = self.evaluate(test_loader)
test_results[test_task_id] = test_acc
# 메트릭 업데이트
self.metrics.update(task_id, test_results)
# 안정성 측정을 위해 이전 태스크 정확도 기록
if task_id > 0 and 0 in test_results:
self.metrics.record_stability_gap(
task_id,
self.metrics.task_accuracies[0]
)
# 로깅
metrics = self.metrics.compute_metrics()
self.log_metrics(epoch, train_loss, train_acc, val_loss, val_acc, test_results, metrics)
# 조기 종료 확인
if val_acc > best_val_acc:
best_val_acc = val_acc
patience_counter = 0
# 최적 모델 저장
self.save_model(f"best_model_task_{task_id}.pt")
else:
patience_counter += 1
if patience_counter >= patience:
logger.info(f"조기 종료: {patience}에포크 동안 개선 없음")
break
# 태스크 완료 후 EWC 등록
self.ewc.register_task(task_id, train_loader, self.ce_loss)
# 최종 성능 저장
self.best_accuracies[task_id] = best_val_acc
# 종합 메트릭 로깅
final_metrics = self.metrics.compute_metrics()
logger.info(f"=== 태스크 {task_id} 훈련 완료 ===")
logger.info(f"최종 메트릭: {final_metrics}")
# 최종 모델 저장
self.save_model(f"final_model_task_{task_id}.pt")
def train_epoch(self, train_loader: DataLoader, epoch: int, task_id: int):
"""
단일 에포크 훈련
Args:
train_loader: 훈련 데이터 로더
epoch: 현재 에포크
task_id: 태스크 ID
Returns:
평균 손실과 정확도
"""
# 모델만 학습 모드로 설정(분류기 제거)
self.model.train()
total_loss = 0
correct = 0
total = 0
# 로그로만 진행상황 표시
logger.info(f"Epoch {epoch} 학습 시작: 총 {len(train_loader)} 배치")
# 기본 tqdm 설정만 사용 (호환성 최대화)
pbar = tqdm(
train_loader,
desc=f"Epoch {epoch}"
)
# 동적 배치 크기 적용
current_batch_size = self.batch_manager.batch_size
for batch_idx, batch_data in enumerate(pbar):
try:
# BatchEncoding 타입 또는 딕셔너리 형태의 데이터 처리
if isinstance(batch_data, dict) or hasattr(batch_data, 'input_ids'):
# input_ids를 입력 데이터로 사용
data = batch_data.get('input_ids', None)
if data is None and 'input_ids' in batch_data:
data = batch_data['input_ids']
# data가 여전히 None이면 배치 건너뛰기
if data is None:
logger.warning(f"배치 {batch_idx}에서 input_ids를 찾을 수 없습니다. 건너뜁니다.")
continue
data = data.to(self.device)
# 언어 모델링에서는 labels가 있을 수 있음
if 'labels' in batch_data:
target = batch_data['labels'].to(self.device)
else:
# input_ids를 target으로 사용
target = data
# 디버그용 로그
if batch_idx == 0:
logger.info(f"배치 데이터 형태: {type(batch_data).__name__}, 키: {batch_data.keys() if hasattr(batch_data, 'keys') else 'N/A'}")
logger.info(f"입력 데이터 형태: {data.shape if hasattr(data, 'shape') else 'unknown'}")
else:
# 기존 (data, target) 튜플 형태 처리
data, target = batch_data
data, target = data.to(self.device), target.to(self.device)
# 학습 단계 수행
if self.use_amp:
loss, acc = self.train_step_amp(data, target, task_id)
else:
loss, acc = self.train_step(data, target, task_id)
# 배치 크기 업데이트 (OOM 없음)
new_batch_size = self.batch_manager.update()
if new_batch_size != current_batch_size:
current_batch_size = new_batch_size
logger.info(f"배치 크기 업데이트: {current_batch_size}")
# 프로그레스 바 통합 업데이트 (출력 빈도 최적화)
if batch_idx % 50 == 0 or batch_idx == 0: # 첫 배치와 50배치마다 프로그레스 업데이트
# 누적 손실과 정확도 계산
running_loss = total_loss / (batch_idx + 1)
running_acc = correct / max(total, 1)
try:
# 표시와 postfix 모두 한 번에 업데이트
pbar.set_description(f"Epoch {epoch} [L: {loss:.4f}]")
pbar.set_postfix({
'loss': f"{running_loss:.4f}",
'acc': f"{running_acc:.4f}",
'lr': f"{self.optimizer.param_groups[0]['lr']:.2e}",
'bs': current_batch_size
}, refresh=False)
except Exception as e:
# tqdm 버전 호환성 오류 무시
pass
# 디버그 로그 - 첫 배치와 100배치마다만 출력
if batch_idx == 0 or batch_idx % 100 == 0:
if isinstance(data, torch.Tensor):
logger.debug(f"데이터 텐서 정보 - 형태: {data.shape}, 디바이스: {data.device}, 타입: {data.dtype}")
elif hasattr(data, 'input_ids') and hasattr(data.input_ids, 'shape'):
logger.debug(f"입력 ID 텐서 정보 - 형태: {data.input_ids.shape}, 디바이스: {data.input_ids.device}, 타입: {data.input_ids.dtype}")
except torch.cuda.OutOfMemoryError:
# OOM 오류 발생 시 복구
torch.cuda.empty_cache()
logger.warning("OOM 오류 발생, 복구 중...")
# 배치 크기 감소
new_batch_size = self.batch_manager.update(oom_occurred=True)
current_batch_size = new_batch_size
# 해당 배치 건너뛰기
continue
except Exception as e:
# "No inf checks" 오류는 학습에 영향을 주지 않으므로 무시
if "No inf checks" in str(e):
# 오류 무시 - 학습 진행
pass
else:
# 배치 처리 중 오류는 로그에 기록
logger.warning(f"배치 처리 중 오류 발생: {e}")
if self.global_step % 50 == 0: # 로그 양 제한
if isinstance(batch_data, dict) or hasattr(batch_data, 'keys'):
logger.debug(f"배치 키: {list(batch_data.keys()) if hasattr(batch_data, 'keys') else 'N/A'}")
# 배치 데이터 구조 출력
logger.error(f"배치 데이터 구조: {type(batch_data)}")
# 데이터 샘플 출력
if hasattr(batch_data, 'input_ids'):
logger.info(f"input_ids 형태: {type(batch_data.input_ids)}, 크기: {batch_data.input_ids.shape if hasattr(batch_data.input_ids, 'shape') else 'N/A'}")
continue
# 손실 및 정확도 누적
total_loss += loss
# data.size(0) 대신 안전하게 배치 크기 추출
batch_size = data.size(0) if hasattr(data, 'size') else 1
correct += acc * batch_size
total += batch_size
# 진행률 표시 - 예외 처리 추가
try:
pbar.set_postfix({
'loss': f"{loss:.4f}",
'acc': f"{acc*100:.2f}%",
'lr': f"{self.scheduler.get_last_lr()[0]:.6f}"
}, refresh=False) # 성능 향상을 위해 refresh=False 추가
except Exception as e:
# tqdm 버전 호환성 오류 무시
pass
# 체크포인트 확인
if self.spot_handler.should_checkpoint(self.global_step):
self.spot_handler.save_checkpoint(
model=self.model,
optimizer=self.optimizer,
scheduler=self.scheduler,
scaler=self.scaler,
trainer=self,
step=self.global_step
)
# 전역 스텝 업데이트
self.global_step += 1
# 에포크 종료 후 최종 요약 출력
logger.info(f"에포크 {epoch} 완료: 손실={total_loss/(batch_idx+1):.4f}, 정확도={correct/max(1,total):.4f}")
# 최종 메트릭 계산
return total_loss / len(train_loader), correct / total if total > 0 else 0
def train_step(self, data: torch.Tensor, target: torch.Tensor, task_id: int):
"""
단일 훈련 스텝 (표준 정밀도) - 언어 모델 내장 손실만 사용
Args:
data: 입력 데이터
target: 타겟 레이블
task_id: 현재 태스크 ID
Returns:
손실, 정확도
"""
# 언어 모델링 손실 계산을 위해 labels 인자로 target 전달
model_output = self.model(data, labels=target)
# 손실 및 정확도 초기화
loss = None
acc = 0.0
try:
# 모델 내장 손실 사용
if hasattr(model_output, 'loss') and model_output.loss is not None:
loss = model_output.loss
if self.global_step % 10 == 0:
logger.info(f"[INFO] 언어 모델 내장 손실 사용: {loss.item():.4f}")
# 정확도 계산 (선택적)
if hasattr(model_output, 'logits'):
# 마지막 토큰 위치의 예측값만 사용하여 간단한 정확도 계산
if len(target.shape) > 1 and len(model_output.logits.shape) > 2:
# 실제 타겟의 마지막 유효한 토큰 인덱스 계산 (패딩 제외)
pred = model_output.logits[:, :-1].argmax(dim=-1)
valid_target = target[:, 1:]
valid_mask = (valid_target != -100) & (valid_target != 0) # 패딩 토큰 및 특수 토큰 제외
if valid_mask.sum() > 0:
correct = (pred == valid_target) & valid_mask
acc = correct.sum().float() / valid_mask.sum()
else:
# 내장 손실이 없는 경우 - 예외 상황
available_attrs = [attr for attr in dir(model_output) if not attr.startswith('_')]
logger.error(f"model_output 속성: {available_attrs}")
raise ValueError(f"모델에서 언어 모델링 손실을 찾을 수 없음: {type(model_output)}")
# logits이 있는 경우 직접 손실 계산
except Exception as e:
logger.error(f"배치 처리 중 오류 발생: {e}")
logger.error(f"입력 데이터 구조: {type(data)}")
# 대체 접근 방법 시도
logger.warning(f"기본 방식 손실 계산 중 오류 발생: {str(e)}")
logger.warning("내장 손실 계산 재시도...")
# 데이터에 레이블이 포함된 경우 (HuggingFace 형식)
if isinstance(data, dict) and 'labels' in data:
model_output = self.model(**data) # 레이블 포함하여 다시 계산
else:
# 다양한 방식 시도
try:
model_output = self.model(data, labels=target) # 레이블 명시적 전달
except Exception as inner_e:
logger.error(f"레이블 전달 시도 실패: {inner_e}")
# 마지막 시도: 레이블 없이 실행 후 손실 수동 계산
model_output = self.model(data)
if hasattr(model_output, 'loss'):
loss = model_output.loss
logger.info(f"[복구] 내장 손실 사용: {loss.item():.4f}")
else:
logger.error("내장 손실 계산 실패, 학습 중단 필요")
raise RuntimeError("모델의 언어 모델링 손실 계산 실패")
# 학습률 스케줄러 스텝 (있는 경우)
if self.scheduler is not None:
self.scheduler.step()
# MER 메모리 업데이트
self.mer.reservoir_sampling(data, target)
# 결과 반환: 손실 값과 정확도(있는 경우)
return loss.item(), acc
def train_step_amp(self, data, target, task_id: int):
"""
단일 훈련 스텝 (혼합 정밀도) - 언어 모델 내장 손실만 사용
Args:
data: 입력 데이터 (torch.Tensor 또는 transformers.BatchEncoding)
target: 타겟 레이블 (torch.Tensor 또는 transformers.BatchEncoding)
task_id: 현재 태스크 ID
Returns:
손실 값, 정확도
"""
self.optimizer.zero_grad()
# 데이터 타입 체크 및 처리
from transformers.tokenization_utils_base import BatchEncoding
# BatchEncoding 객체 처리
if isinstance(data, BatchEncoding):
# BatchEncoding 로깅 최적화: 초기 한 번만 info로 기록, 이후에는 매우 낮은 빈도로 debug로 기록
if self.global_step == 0: # 학습 첫 시작시에만 info 레벨로 기록
logger.info(f"BatchEncoding 데이터 처리: 키={list(data.keys())}, 디바이스={data.get('input_ids').device if 'input_ids' in data else '알 수 없음'}")
elif self.global_step % 5000 == 0: # 이후에는 5000 스텝마다 debug 레벨로만 기록
logger.debug(f"BatchEncoding 데이터 처리: 키={list(data.keys())}, 디바이스={data.get('input_ids').device if 'input_ids' in data else '알 수 없음'}")
# 필요한 경우에만 input_ids 추출
if 'input_ids' in data:
input_tensor = data['input_ids']
else:
available_keys = list(data.keys())
logger.error(f"BatchEncoding에서 input_ids를 찾을 수 없음. 사용 가능한 키: {available_keys}")
raise ValueError(f"BatchEncoding에 input_ids가 없습니다. 사용 가능한 키: {available_keys}")
else:
# 이미 텐서인 경우 그대로 사용
input_tensor = data
# target도 동일하게 처리
if isinstance(target, BatchEncoding):
if 'labels' in target:
labels_tensor = target['labels']
elif 'input_ids' in target:
# fallback: input_ids를 labels로 사용
labels_tensor = target['input_ids']
else:
available_keys = list(target.keys())
logger.error(f"BatchEncoding에서 labels를 찾을 수 없음. 사용 가능한 키: {available_keys}")
raise ValueError(f"BatchEncoding에 labels가 없습니다. 사용 가능한 키: {available_keys}")
else:
# 이미 텐서인 경우 그대로 사용
labels_tensor = target
# 혼합 정밀도 연산 - 최신 PyTorch API 형식 사용
with autocast('cuda'):
# 언어 모델링 손실 계산을 위해 labels 인자로 target 전달
# DeepSeek-Coder 및 다른 Hugging Face 모델은 labels를 통한 직접 손실 계산 지원
model_output = self.model(input_tensor, labels=labels_tensor)
# 언어 모델 내장 손실 사용
try:
# 1. 모델이 내장 손실을 제공하는 경우 (허깅페이스 transformers 모델)
if hasattr(model_output, 'loss') and model_output.loss is not None:
# 손실을 추출하고 그래디언트 추적을 위해 명시적으로 requires_grad=True 설정
loss = model_output.loss
# 그래디언트 추적이 필요하지만 없는 경우
if not loss.requires_grad:
# 새로운 텐서로 복사하고 그래디언트 추적 활성화
loss = loss.detach().clone().requires_grad_(True)
# 로그 출력 빈도 조절 (500배치마다 한 번만 출력)
if self.global_step % 500 == 0:
logger.info(f"LM 내장 손실 사용: {loss.item():.4f}, requires_grad={loss.requires_grad}, 학습률={self.scheduler.get_last_lr()[0]:.6f}")
else:
# 내장 손실이 없는 경우 - 로깅 후 예외 발생
available_attrs = [attr for attr in dir(model_output) if not attr.startswith('_')]
logger.error(f"model_output 속성: {available_attrs}")
raise ValueError(f"모델에서 언어 모델링 손실을 찾을 수 없음: {type(model_output)}")
except Exception as e:
logger.error(f"손실 계산 중 오류 발생: {e}")
logger.error(f"model_output 타입: {type(model_output)}")
raise RuntimeError(f"LM 손실 계산 실패: {e}")
# 정확도 계산 (선택적) - 완전한 분류기 대신 단순한 정확도 추정 사용
acc = 0.0
try:
if hasattr(model_output, 'logits'):
# 마지막 토큰 위치의 예측값만 사용하여 간단한 정확도 계산
# (실제 LM 성능은 perplexity 등으로 별도 평가해야 함)
if len(target.shape) > 1 and len(model_output.logits.shape) > 2:
# 실제 타겟의 마지막 유효한 토큰 인덱스 계산 (패딩 제외)
pred = model_output.logits[:, :-1].argmax(dim=-1)
valid_target = target[:, 1:]
valid_mask = (valid_target != -100) & (valid_target != 0) # 패딩 토큰 및 특수 토큰 제외
if valid_mask.sum() > 0:
correct = (pred == valid_target) & valid_mask
acc = correct.sum().float() / valid_mask.sum()
except Exception as e:
logger.warning(f"정확도 계산 중 오류 발생 (무시됨): {e}")
acc = 0.0
# EWC 손실 (이전 태스크 보존) - 필요한 경우
try:
# EWC 손실 계산 전에 자원 확보
if hasattr(self, 'ewc') and self.ewc is not None:
ewc_loss = self.ewc.compute_ewc_loss()
# 그래디언트 추적이 필요하면 자동으로 설정
if not ewc_loss.requires_grad and ewc_loss.item() != 0.0:
# 가능한 한 조용히 처리 - 로그 없이 자동 변환
ewc_loss = ewc_loss.detach().clone().requires_grad_(True)
# 손실값 유효성 검사
if torch.isfinite(ewc_loss).all() and torch.isfinite(loss).all():
loss = loss + ewc_loss
if self.global_step % 20 == 0 and ewc_loss.item() > 0.0001:
# 유의미한 EWC 손실이 있을 때만 기록
logger.info(f"LM 손실: {loss.item():.4f}, EWC 손실: {ewc_loss.item():.4f}")
else:
# EWC가 없는 경우 아무것도 하지 않음
pass
except Exception as e:
# 중요하지 않은 예외는 조용히 무시
pass
# 역전파 전에 손실이 그래디언트 추적을 요구하는지 확인
try:
if not loss.requires_grad:
logger.warning(f"[안전 조치] 손실이 그래디언트 추적을 요구하지 않음. 명시적으로 requires_grad=True 설정")
loss = loss.detach().clone().requires_grad_(True)
# 스케일링된 역전파
scaled_loss = self.scaler.scale(loss)
scaled_loss.backward()
if self.global_step % 50 == 0:
logger.debug(f"backward 호출 성공: 손실={loss.item():.4f}, requires_grad={loss.requires_grad}")
except Exception as e:
logger.error(f"역전파 오류: {e}")
# 간단한 대체 역전파 시도 (응급 조치)
try:
if hasattr(loss, 'backward'):
loss.backward()
logger.info("대체 역전파 방법 사용")
except Exception as e2:
logger.error(f"대체 역전파도 실패: {e2}")
# 실패하면 현재 배치 건너뛼
# 그래디언트 클리핑 (선택적)
if self.config.get('grad_clip', 0) > 0:
try:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config.get('grad_clip', 1.0)
)
except RuntimeError as e:
logger.warning(f"그래디언트 클리핑 중 오류 (무시함): {e}")
# 스케일러를 통한 옵티마이저 스텝 - 오류 처리 개선
try:
# 옵티마이저 스텝 수행
self.scaler.step(self.optimizer)
self.scaler.update()
# 학습률 스케쥴러 업데이트
if self.scheduler is not None:
self.scheduler.step()
except RuntimeError as e:
error_msg = str(e)
if "No inf checks" in error_msg:
# 이 오류는 학습에 실질적 영향이 없으므로 로그 없이 무시
pass
elif "element 0 of tensors does not require grad" in error_msg:
# 그래디언트 추적 문제 - 다음 배치에서 문제 해결 시도
logger.warning("[문제 해결 시도] 텐서에 그래디언트 추적이 활성화되지 않은 문제")
# 옵티마이저 상태 초기화 (경계 상태)
self.optimizer.zero_grad()
else:
# 다른 중요한 오류는 기록 후 넘김
logger.warning(f"옵티마이저 스텝 중 오류 발생: {e}")
# 특정 유형의 오류만 무시하고 다음 배치로 진행
if "CUDA out of memory" in error_msg or "c10:" in error_msg:
logger.error("추가 디버그 정보: CUDA 특정 오류 문제 무시")
# 메모리 정리
torch.cuda.empty_cache()
else:
raise
# 학습률 스케줄러 스텝
self.scheduler.step()
# MER 메모리 업데이트 (한번만 실행)
self.mer.reservoir_sampling(data, target)
return loss.item(), acc
def evaluate(self, data_loader: DataLoader):
"""
모델 평가
Args:
data_loader: 평가 데이터 로더
Returns:
평균 손실, 정확도
"""
self.model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for data, target in data_loader:
data, target = data.to(self.device), target.to(self.device)
try:
# 언어 모델의 내장 손실 함수 사용
outputs = self.model(data, labels=target)
loss = outputs.loss
total_loss += loss.item()
# 로그엣과 평균 정확도 계산 (간단한 버전)
logits = outputs.logits # [batch_size, sequence_length, vocab_size]
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = target[:, 1:].contiguous()
# 참고: 여기서의 정확도는 단어 레벨 정확도의 대략적인 추정
flattened_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
flattened_shift_labels = shift_labels.view(-1)
# 유효한 토큰에 대해서만 정확도 계산 (패딩 토큰 제외)
valid_indices = (flattened_shift_labels != -100)
if valid_indices.sum() > 0:
valid_logits = flattened_shift_logits[valid_indices]
valid_labels = flattened_shift_labels[valid_indices]
predictions = valid_logits.argmax(dim=-1)
correct += (predictions == valid_labels).sum().item()
total += valid_indices.sum().item()
except Exception as e:
logger.error(f"평가 중 오류 발생: {str(e)}")
continue
# 평균 손실 및 정확도 계산
avg_loss = total_loss / len(data_loader) if len(data_loader) > 0 else 0
accuracy = correct / total if total > 0 else 0
return avg_loss, accuracy
def log_metrics(self, epoch: int, train_loss: float, train_acc: float,
val_loss: float, val_acc: float, test_results: Dict[int, float],
metrics: Dict[str, float]):
"""
메트릭 로깅
Args:
epoch: 현재 에포크
train_loss: 훈련 손실
train_acc: 훈련 정확도
val_loss: 검증 손실
val_acc: 검증 정확도
test_results: 이전 태스크 테스트 결과
metrics: 계산된 메트릭
"""
# 콘솔 로깅
log_str = f"Epoch {epoch}: "
log_str += f"Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.2f}% | "
log_str += f"Val Loss: {val_loss:.4f}, Acc: {val_acc*100:.2f}% | "
for task_id, acc in test_results.items():
log_str += f"Task {task_id} Acc: {acc*100:.2f}% | "
log_str += f"ACC: {metrics['ACC']*100:.2f}%, BWT: {metrics['BWT']*100:.2f}%, "
log_str += f"Stability Gap: {metrics['stability_gap']*100:.2f}%"
logger.info(log_str)
# TensorBoard 로깅
step = epoch + self.current_task_id * self.config.get('max_epochs', 10)
self.writer.add_scalar('Loss/train', train_loss, step)
self.writer.add_scalar('Loss/val', val_loss, step)
self.writer.add_scalar('Accuracy/train', train_acc, step)
self.writer.add_scalar('Accuracy/val', val_acc, step)
for task_id, acc in test_results.items():
self.writer.add_scalar(f'Accuracy/task_{task_id}', acc, step)
self.writer.add_scalar('Metrics/ACC', metrics['ACC'], step)
self.writer.add_scalar('Metrics/BWT', metrics['BWT'], step)
self.writer.add_scalar('Metrics/FWT', metrics['FWT'], step)
self.writer.add_scalar('Metrics/stability_gap', metrics['stability_gap'], step)
self.writer.add_scalar('Metrics/recovery_ratio', metrics['recovery_ratio'], step)
# 학습률 로깅
self.writer.add_scalar('LR', self.scheduler.get_last_lr()[0], step)
def save_model(self, filename: str):
"""
모델 저장
Args:
filename: 저장할 파일 이름
"""
save_path = os.path.join(self.config.get('checkpoint_dir', 'checkpoints'), filename)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
state_dict = {
'model': self.model.state_dict(),
# 분류기 제거 - 언어 모델 내장 손실만 사용
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict() if self.scheduler else None,
'scaler': self.scaler.state_dict() if self.scaler else None,
'metrics': self.metrics.get_state_dict(),
'config': self.config,
'global_step': self.global_step,
'current_task_id': self.current_task_id,
'best_accuracies': self.best_accuracies,
'ewc': self.ewc.get_state_dict(),
'mer': self.mer.get_state_dict()
}
torch.save(state_dict, save_path)
logger.info(f"모델이 저장되었습니다: {save_path}")
def load_model(self, filename: str):
"""
모델 로드
Args:
filename: 로드할 파일 이름
"""
load_path = os.path.join(self.config.get('checkpoint_dir', 'checkpoints'), filename)
if not os.path.exists(load_path):
logger.warning(f"로드할 모델 파일이 없습니다: {load_path}")
return False
state_dict = torch.load(load_path, map_location=self.device)
self.model.load_state_dict(state_dict['model'])
# 분류기 로드 제거 - 언어 모델만 사용
self.optimizer.load_state_dict(state_dict['optimizer'])
if 'scheduler' in state_dict and state_dict['scheduler'] and self.scheduler:
self.scheduler.load_state_dict(state_dict['scheduler'])
if 'scaler' in state_dict and state_dict['scaler'] and self.scaler:
self.scaler.load_state_dict(state_dict['scaler'])
if 'metrics' in state_dict:
self.metrics.load_state_dict(state_dict['metrics'])
if 'ewc' in state_dict:
self.ewc.load_state_dict(state_dict['ewc'])
if 'mer' in state_dict:
self.mer.load_state_dict(state_dict['mer'])
self.global_step = state_dict.get('global_step', 0)
self.current_task_id = state_dict.get('current_task_id', 0)
self.best_accuracies = state_dict.get('best_accuracies', {})
logger.info(f"모델을 로드했습니다: {load_path}")
return True
def get_state_dict(self):
"""상태 저장을 위한 딕셔너리 반환"""
return {
'metrics': self.metrics.get_state_dict(),
'global_step': self.global_step,
'current_task_id': self.current_task_id,
'best_accuracies': self.best_accuracies,
'ewc': self.ewc.get_state_dict(),
'mer': self.mer.get_state_dict()
}
def load_state_dict(self, state_dict):
"""저장된 상태 로드"""
if 'metrics' in state_dict:
self.metrics.load_state_dict(state_dict['metrics'])
if 'ewc' in state_dict:
self.ewc.load_state_dict(state_dict['ewc'])
if 'mer' in state_dict:
self.mer.load_state_dict(state_dict['mer'])
self.global_step = state_dict.get('global_step', 0)
self.current_task_id = state_dict.get('current_task_id', 0)
self.best_accuracies = state_dict.get('best_accuracies', {})
# 메인 함수
# wandb 명시적 비활성화 (스크립트 맨 위에 추가)
os.environ["WANDB_DISABLED"] = "true"
def main():
"""메인 실행 함수"""
parser = argparse.ArgumentParser(description='CapaBoost + Continual Learning 통합 시스템')
parser.add_argument('--config', type=str, default='config.yaml', help='설정 파일 경로')
parser.add_argument('--mode', type=str, choices=['complete', 'prompt', 'comment', 'error_fix'], default='prompt',
help='학습 모드: complete(자동완성), prompt(일반 프롬프트), comment(주석), error_fix(오류 수정)')
parser.add_argument('--seed', type=int, default=42, help='랜덤 시드')
parser.add_argument('--resume', action='store_true', help='이전 체크포인트에서 재개')
parser.add_argument('--resume_from_checkpoint', type=str, help='특정 체크포인트 경로에서 재개')
parser.add_argument('--enable_capaboost', action='store_true', help='CapaBoost 기능 활성화')
args = parser.parse_args()
# 설정 파일 로드
with open(args.config, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
# 명령행 인수로 설정 덮어쓰기
if args.seed:
config['seed'] = args.seed
# 학습 모드 추가
config['mode'] = args.mode
logger.info(f"\ud559\uc2b5 \ubaa8\ub4dc: {args.mode}")
# 모드별 출력 경로 설정
if args.mode == 'complete':
config['output_dir'] = config.get('output_dir', '../models/autocomplete-finetuned/')
logger.info("[1차] 코드 자동완성 학습 모드 (FIM 형식 지원)")
elif args.mode == 'prompt':
config['output_dir'] = config.get('output_dir', '../models/prompt-finetuned/')
logger.info("[2차] 일반 프롬프트 기반 코드 생성 모드")
elif args.mode == 'comment':
config['output_dir'] = config.get('output_dir', '../models/comment-finetuned/')
logger.info("[3차] 주석 기반 코드 생성 모드")
elif args.mode == 'error_fix':
config['output_dir'] = config.get('output_dir', '../models/error-fix-finetuned/')
logger.info("[4차] 코드 오류 설명 및 수정 모드")
# CapaBoost 활성화 여부 로깅 및 설정에 추가
capaboost_enabled = args.enable_capaboost or (os.environ.get('ENABLE_CAPABOOST', '0') == '1')
if capaboost_enabled:
logger.info("CapaBoost 기능이 활성화되었습니다.")
config['enable_capaboost'] = True
else:
logger.info("CapaBoost 기능이 비활성화되었습니다.")
config['enable_capaboost'] = False
# 통합 학습기 초기화
learner = ContinualLearner(config)
# 체크포인트 경로 설정
checkpoint_path = None
if args.resume_from_checkpoint:
checkpoint_path = args.resume_from_checkpoint
logger.info(f"\uccb4\ud06c\ud3ec\uc778\ud2b8 \uacbd\ub85c\uc5d0\uc11c \uc7ac\uac1c: {checkpoint_path}")
elif args.resume:
checkpoint_path = "latest" # 최근 체크포인트에서 재개
logger.info("\ucd5c\uc2e0 \uccb4\ud06c\ud3ec\uc778\ud2b8\uc5d0\uc11c \uc7ac\uac1c")
# 체크포인트 로드
if checkpoint_path:
if checkpoint_path == "latest":
learner.spot_handler.load_latest_checkpoint(
model=learner.model,
optimizer=learner.optimizer,
scheduler=learner.scheduler,
scaler=learner.scaler,
trainer=learner
)
else:
learner.spot_handler.load_checkpoint(
checkpoint_path,
model=learner.model,
optimizer=learner.optimizer,
scheduler=learner.scheduler,
scaler=learner.scaler,
trainer=learner
)
# 모드별 데이터 전처리 및 포맷 적용
logger.info(f"모드별 데이터 처리: {args.mode} 모드용 데이터셋 준비 중...")
# 데이터 로더 생성 - 모드별 처리 로직 적용
# train.py의 데이터 형식 처리 로직을 가져와서 구현
train_dataset = None
val_dataset = None
# 데이터 셋 경로 설정 - 사용자 지정 경로로 변경
data_path = os.path.expanduser("~/deepseek-coder/data")
logger.info(f"데이터 디렉토리 경로: {data_path}")
# 모드별 추가 로깅
if args.mode == 'complete':
logger.info(f"자동완성(FIM) 모드 사용")
elif args.mode == 'prompt':
logger.info(f"프롬프트 생성 모드 사용")
elif args.mode == 'comment':
logger.info(f"주석 기반 생성 모드 사용")
elif args.mode == 'error_fix':
logger.info(f"오류 수정 모드 사용")
# 데이터셋 로드 (datasets 라이브러리 사용)
if not DATASETS_AVAILABLE:
raise ImportError("datasets 라이브러리를 찾을 수 없습니다. 'pip install datasets'를 실행하여 설치하세요.")
try:
logger.info(f"데이터셋 로드 중: {data_path}")
# 데이터셋 파일 존재 여부 확인 - 훈련 데이터만 체크 (.jsonl 형식)
train_path = os.path.join(data_path, 'train.jsonl')
if not os.path.exists(train_path):
logger.error(f"훈련 데이터 파일을 찾을 수 없습니다: {train_path}")
raise FileNotFoundError(f"훈련 데이터 파일을 찾을 수 없습니다: {train_path}")
# 훈련 데이터 파일만 사용 (검증 데이터는 후에 분리)
data_files = {'train': train_path}
raw_dataset = load_dataset(
'json',
data_files=data_files,
cache_dir=config.get('cache_dir', '.cache'),
# jsonl 형식 처리를 위한 옵션
streaming=False # 지정한 경우에만 False로 설정하여 메모리에 로드
)
# 훈련/검증 분리
validation_split_percentage = config.get('validation_split_percentage', 10)
logger.info(f"훈련 데이터에서 검증 데이터 {validation_split_percentage}% 분리 시작")
# 기존 validation 세트가 없는 경우에만 분리 수행
if 'validation' not in raw_dataset:
split_datasets = raw_dataset["train"].train_test_split(
test_size=validation_split_percentage / 100,
seed=config.get('seed', 42)
)
raw_dataset["train"] = split_datasets["train"]
raw_dataset["validation"] = split_datasets["test"]
logger.info(f"최종 데이터셋 크기 - 훈련: {len(raw_dataset['train'])}, 검증: {len(raw_dataset['validation'])}")
# 모드별 적합한 데이터 전처리
# 기본 DeepSeek-Coder 모델 경로 지정 (절대경로 사용)
default_model_path = "/home/ubuntu/deepseek-coder/model_cache/deepseek-coder-6.7b-instruct"
# 구성에서 모델 경로 가져오기 (model_name 또는 model_name_or_path)
model_path = config.get('model_name', config.get('model_name_or_path', default_model_path))
logger.info(f"토크나이저 로드 중: {model_path}")
try:
# 로컬 경로인지 확인 (절대 경로 또는 ~로 시작하는 경로)
is_local_path = os.path.isabs(model_path) or model_path.startswith('~') or model_path.startswith('/')
# 로컬 경로면 local_files_only=True 옵션 추가
tokenizer = AutoTokenizer.from_pretrained(
model_path,
cache_dir=config.get('cache_dir', '.cache'),
use_fast=True,
local_files_only=is_local_path, # 로컬 파일만 사용
trust_remote_code=True # 안전한 환경에서 원격 코드 신뢰
)
logger.info(f"토크나이저 로드 성공: {model_path}")
except Exception as e:
logger.error(f"토크나이저 로드 오류: {e}")
logger.info(f"기본 모델 {default_model_path}로 재시도 중...")
# 기본 DeepSeek-Coder 모델로 재시도 (로컬 파일 사용 옵션 추가)
tokenizer = AutoTokenizer.from_pretrained(
default_model_path,
cache_dir=config.get('cache_dir', '.cache'),
use_fast=True,
local_files_only=True, # 로컬 파일만 사용
trust_remote_code=True # 안전한 환경에서 원격 코드 신뢰
)
# 최대 길이 설정
max_length = config.get('max_length', 2048)
# 데이터 형식을 자동으로 감지하는 함수 추가
def detect_format(example):
"""데이터 형식을 감지하는 함수"""
if "messages" in example:
return "chat"
elif "prompt" in example and "completion" in example:
return "prompt_completion"
elif "prefix_code" in example and "suffix_code" in example and "comment" in example and "target_code" in example:
# 주석 기반 코드 생성 형식 - 새로운 데이터 구조
return "comment_to_code"
elif "error_context" in example and "explanation" in example:
return "error_explanation"
elif "error_context" in example and "fixed_code_snippet" in example:
# error_fix 형식 감지 개선 - buggy_code_snippet이 error_context 내부에 있는 경우도 처리
error_context = example["error_context"]
if isinstance(error_context, dict) and "buggy_code_snippet" in error_context:
return "error_fix"
# 또는 buggy_code_snippet이 최상위 필드에 있는 경우
elif "buggy_code_snippet" in example:
return "error_fix"
elif "instruction" in example and "input" in example and "output" in example:
return "instruction_input_output"
elif "content" in example and args.mode == "complete":
# complete 모드용 단순 content 필드만 있는 경우 (FIM 형식일 수 있음)
return "complete"
elif "comment" in example and "code" in example and args.mode == "comment":
# 주석과 코드가 있는 단순 형식
return "comment_code"
else:
# 데이터 구조 로깅
logger.warning(f"⚠️ 알 수 없는 데이터 형식: {list(example.keys())}")
return "unknown"
# 데이터 형식 감지 (첫 번째 샘플로 판단)
try:
sample = raw_dataset["train"][0]
data_format = detect_format(sample)
logger.info(f"✅ 감지된 데이터 형식: {data_format}")
except Exception as e:
logger.error(f"데이터 형식 감지 중 오류 발생: {e}")
data_format = "unknown"
# FIM 형식 사용 여부를 미리 확인 (로그 중복 방지)
has_fim_format = False
if data_format == "prompt_completion" and args.mode == "comment":
# 첫 몇 개 샘플 검사하여 FIM 태그 사용 여부 확인
sample_check_count = min(5, len(raw_dataset["train"]))
for i in range(sample_check_count):
sample_prompt = raw_dataset["train"][i]["prompt"]
if "<|fim begin|>" in sample_prompt or "<|fim hole|>" in sample_prompt or "<|fim end|>" in sample_prompt:
logger.info("✅ FIM 형식 주석 기반 모델 데이터 감지됨")
has_fim_format = True
break
# 모드별 전처리 함수 정의
def process_complete_data(examples):
# FIM (Fill In the Middle) 형식 처리
results = tokenizer(
examples['content'],
truncation=True,
max_length=max_length,
return_attention_mask=False
)
return results
def process_prompt_data(examples):
# 프롬프트 기반 생성 형식 처리
results = tokenizer(
examples['instruction'] + '\n' + examples['input'] + '\n' + examples['output'],
truncation=True,
max_length=max_length,
return_attention_mask=False
)
return results
def process_comment_data(examples):
# 주석 기반 코드 생성 처리
results = tokenizer(
examples['comment'] + '\n' + examples['code'],
truncation=True,
max_length=max_length,
return_attention_mask=False
)
return results
def process_error_fix_data(examples):
# 오류 수정 데이터 처리
results = tokenizer(
examples['error_description'] + '\n' + examples['buggy_code'] + '\n' + examples['fixed_code'],
truncation=True,
max_length=max_length,
return_attention_mask=False
)
return results
# 향상된 토큰화 함수 - train.py의 로직 통합
def tokenize_function(examples):
# 모드별, 데이터 형식별 처리를 통합
texts = []
# 샘플 수 결정 (다양한 형식 지원)
sample_count = 0
if "content" in examples and args.mode == 'complete':
sample_count = len(examples["content"])
elif "messages" in examples:
sample_count = len(examples["messages"])
elif "prompt" in examples:
sample_count = len(examples["prompt"])
elif "prefix_code" in examples:
sample_count = len(examples["prefix_code"])
elif "instruction" in examples:
sample_count = len(examples["instruction"])
elif "comment" in examples and "code" in examples:
sample_count = len(examples["comment"])
elif "error_description" in examples and "buggy_code" in examples:
sample_count = len(examples["error_description"])
else:
logger.error("❌ 지원되지 않는 데이터 형식")
return {"input_ids": []}
# 각 샘플 처리
for i in range(sample_count):
# 데이터 형식 및 모드별 처리
if data_format == "complete" or (args.mode == 'complete' and "content" in examples):
# 자동완성(complete) 모드
if "content" in examples:
text = examples["content"][i]
texts.append(text)
elif data_format == "prompt_completion":
# 주석 키워드 기반 코드 생성 형식
prompt = examples["prompt"][i] if "prompt" in examples else ""
completion = examples["completion"][i] if "completion" in examples else ""
# FIM 형식 처리 (3차 주석 기반 모델)
if args.mode == "comment" and ("<|fim begin|>" in prompt or "<|fim hole|>" in prompt or "<|fim end|>" in prompt):
# FIM 형식 그대로 유지하고 ChatML 형식으로 래핑
text = f"assistant\n{prompt}\n\nassistant\n{completion}\n"
else:
# 일반 프롬프트를 사용자 메시지로, 완성을 어시스턴트 메시지로 변환
text = f"assistant\n{prompt}\n\nassistant\n{completion}\n"
texts.append(text)
elif data_format == "comment_to_code" or data_format == "comment_code" or args.mode == 'comment':
# 주석 기반 코드 생성 형식
if data_format == "comment_to_code":
prefix_code = examples["prefix_code"][i] if "prefix_code" in examples else ""
suffix_code = examples["suffix_code"][i] if "suffix_code" in examples else ""
comment = examples["comment"][i]
target_code = examples["target_code"][i] if "target_code" in examples else ""
# 사용자 입력 구성 (주석과 코드 컨텍스트)
user_text = f"주석에 따라 적절한 코드를 생성해주세요.\n\n"
user_text += f"주석: {comment}\n\n"
user_text += "\n이전 코드:\n"
user_text += f"{prefix_code}\n"
user_text += "// 여기에 코드를 삽입해야 함 //\n"
user_text += f"{suffix_code}"
# ChatML 형식으로 변환
text = f"assistant\n{user_text}\n\nassistant\n{target_code}\n"
else:
# 간단한 주석-코드 형식
comment = examples["comment"][i]
code = examples["code"][i]
text = f"assistant\n주석: {comment}\n\nassistant\n{code}\n"
texts.append(text)
elif data_format == "error_fix" or args.mode == 'error_fix':
# 에러 수정 형식
if "error_description" in examples and "buggy_code" in examples and "fixed_code" in examples:
# 단순 에러 수정 형식
error_desc = examples["error_description"][i]
buggy_code = examples["buggy_code"][i]
fixed_code = examples["fixed_code"][i]
text = f"assistant\n오류 설명: {error_desc}\n\nassistant\n{buggy_code}\n\nassistant\n{fixed_code}\n"
texts.append(text)
elif "error_context" in examples:
# 복잡한 에러 수정 형식
error_context = examples["error_context"][i]
fixed_code = examples["fixed_code_snippet"][i] if "fixed_code_snippet" in examples else ""
# 에러 컨텍스트 정보 구성
error_log = error_context.get("error_log", "") if isinstance(error_context, dict) else ""
language = error_context.get("language", "") if isinstance(error_context, dict) else ""
# buggy_code가 error_context 안에 있는 경우와 외부에 있는 경우 모두 처리
if isinstance(error_context, dict) and "buggy_code_snippet" in error_context:
buggy_code = error_context["buggy_code_snippet"]
elif "buggy_code_snippet" in examples:
buggy_code = examples["buggy_code_snippet"][i]
else:
buggy_code = ""
logger.warning("⚠️ buggy_code_snippet을 찾을 수 없습니다")
# 사용자 입력 구성
user_text = f"다음 {language} 코드의 에러를 수정해주세요:\n\n에러 로그:\n{error_log}\n\n코드:\n{buggy_code}"
text = f"assistant\n{user_text}\n\nassistant\n{fixed_code}\n"
texts.append(text)
else:
# 알 수 없는 형식은 원시 데이터 텍스트를 그대로 사용
logger.warning(f"⚠️ 지원되지 않는 데이터 형식: {data_format}, 모드: {args.mode}")
# 첫 번째 필드만 사용
if len(examples) > 0:
first_key = list(examples.keys())[0]
text = str(examples[first_key][i])
texts.append(text)
else:
texts.append("")
# 토크나이징
model_inputs = tokenizer(
texts,
truncation=True,
max_length=max_length,
return_attention_mask=False,
padding=False
)
return model_inputs
# 모드와 데이터 형식에 따른 처리 함수 선택
logger.info(f"모드: {args.mode}, 감지된 데이터 형식: {data_format}")
# 데이터셋 전처리 적용 - 향상된 토큰화 함수 사용
# 데이터셋 토큰화 로깅 개선
logger.info(f"✨ {args.mode} 데이터셋 토큰화 시작 (num_proc={config.get('preprocessing_num_workers', 4)})")
# tqdm 설정 관련 코드
# 라이브러리 버전 호환성 문제로 인해 tqdm_kwargs 매개변수를 사용하지 않음
for bar in tqdm._instances:
bar.close()
# 토큰화 실행 - 기본 tqdm 설정 사용
# 특히 AWS의 datasets 라이브러리에서는 tqdm_kwargs를 지원하지 않을 수 있음
logger.info("토큰화 시작 - 토큰화 진행중 이후 로그를 확인하세요")
tokenized_dataset = raw_dataset.map(
tokenize_function,
batched=True,
remove_columns=raw_dataset['train'].column_names,
num_proc=config.get('preprocessing_num_workers', 4),
desc=f"{args.mode} 데이터셋 토큰화 중"
# tqdm_kwargs 매개변수 제거 - 호환성 문제 해결
)
logger.info(f"✅ {args.mode} 데이터셋 토큰화 완료: 훈련={len(tokenized_dataset['train'])}개, 검증={len(tokenized_dataset['validation'])}개")
# DataLoader 생성을 위한 준비
train_dataset = tokenized_dataset['train']
val_dataset = tokenized_dataset['validation']
# DataCollatorForLanguageModeling 로드
try:
from transformers import DataCollatorForLanguageModeling
logger.info("DataCollatorForLanguageModeling 로드 성공")
except ImportError as e:
logger.error(f"DataCollatorForLanguageModeling 로드 오류: {e}")
raise ImportError("transformers 라이브러리에서 DataCollatorForLanguageModeling을 로드할 수 없습니다. 'pip install transformers'를 실행하여 최신 버전으로 업그레이드하세요.")
# DataLoader 생성
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False
)
# collate_fn으로 커스텀 함수 사용하여 BatchEncoding 처리 보장
def safe_collate_fn(features):
try:
# 원래 collator 호출
batch = data_collator(features)
# 디버깅: 첫 번째 배치만 로깅
if logger.level <= logging.DEBUG and random.random() < 0.01: # 1% 샘플링
logger.debug(f"배치 타입: {type(batch)}, 키: {list(batch.keys()) if hasattr(batch, 'keys') else 'N/A'}")
if hasattr(batch, 'input_ids'):
logger.debug(f"input_ids 형태: {batch.input_ids.shape if hasattr(batch.input_ids, 'shape') else 'N/A'}")
return batch
except Exception as e:
logger.error(f"Collate 함수 오류: {e}")
# 긴급 폴백: 원본 특성 그대로 반환
return features
# 훈련용 데이터 로더 생성 - 상세 정보 로깅
batch_size = config.get('batch_size', 32)
num_workers = config.get('dataloader_num_workers', 4)
logger.info(f"학습 DataLoader 초기화: 배치 크기={batch_size}, 작업자={num_workers}")
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
collate_fn=safe_collate_fn, # 안전한 collate 함수 사용
shuffle=True,
pin_memory=True,
num_workers=num_workers
)
# 학습 로더 생성 후 추가 정보
logger.info(f"학습 DataLoader 생성 완료: 배치 크기={batch_size}, 배치 개수={len(train_loader):,}")
val_loader = DataLoader(
val_dataset,
batch_size=config.get('batch_size', 32),
collate_fn=safe_collate_fn, # 안전한 collate 함수 사용
shuffle=False,
pin_memory=True,
num_workers=config.get('dataloader_num_workers', 4)
)
# 로그에 데이터 정보 기록 - 상세 정보 추가
total_steps = len(train_loader)
logger.info(f"훈련 데이터 요약 - 학습용: {len(train_dataset):,} 샘플, 검증용: {len(val_dataset):,} 샘플")
logger.info(f"학습 배치 정보 - 배치당 {batch_size}개 샘플 포함, 총 {total_steps:,}개 배치")
logger.info(f"학습 예상 시간: 에포크당 최소 {(total_steps*1.5)/60:.1f}분 (배치당 1.5초 가정)")
# 지속학습을 위한 태스크 처리
# 모드별로 데이터를 별도의 태스크로 간주
task_id = {'complete': 0, 'prompt': 1, 'comment': 2, 'error_fix': 3}.get(args.mode, 0)
# 테스트 로더는 검증 데이터로 대체 (실제 환경에서는 별도의 테스트 데이터 사용 권장)
test_loaders = {task_id: val_loader}
# AWS 스팟 인스턴스 중단 감시 시작
logger.info("AWS 스팟 인스턴스 중단 감시 시작")
learner.spot_handler.start_monitoring(
model=learner.model,
optimizer=learner.optimizer,
scheduler=learner.scheduler,
scaler=learner.scaler,
trainer=learner
)
# 태스크 훈련 실행
logger.info(f"태스크 {task_id} ({args.mode} 모드) 훈련 시작")
learner.train_task(task_id, train_loader, val_loader, test_loaders)
# AWS 스팟 인스턴스 중단 감시 종료
learner.spot_handler.stop_monitoring()
except Exception as e:
logger.error(f"데이터 로드 또는 훈련 중 오류 발생: {str(e)}")
import traceback
logger.error(traceback.format_exc())
sys.exit(1)
# 최종 메트릭 계산 및 결과 저장
final_metrics = learner.metrics.compute_metrics()
results_path = os.path.join(config.get('log_dir', 'runs/continual_learning'), 'final_results.json')
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, 'w', encoding='utf-8') as f:
json.dump({
'metrics': final_metrics,
'best_accuracies': learner.best_accuracies,
'config': config
}, f, indent=2)
logger.info(f"최종 결과가 저장되었습니다: {results_path}")
logger.info(f"최종 메트릭: {final_metrics}")
if __name__ == '__main__':
main()