[AI] Continual Learning 기술 스택 정리

seongyun·2025년 7월 6일

Hancom Project

목록 보기
7/12

1. ContinualLearner

  • 역할: 지속학습 전체 파이프라인의 핵심 컨트롤러.
  • 기능:
    • 베이스 모델 및 파인튜닝된 모델의 가중치 로드
    • 새로운 태스크(데이터셋)마다 학습/평가 루프 관리
    • EWC, MER, Replay Buffer 등 다양한 지속학습 전략을 선택적으로 적용
    • 체크포인트 저장/복원, 학습률 스케줄링, AMP(혼합정밀도) 관리 등 전체 학습 환경 통제

2. FeatureAdapter

  • 역할: 언어모델의 대용량 출력(예: logits, hidden states)을 분류기나 NCM 등 다운스트림 모듈이 기대하는 저차원 특징으로 변환
  • 기능:
    • (batch, seq_len, vocab_size) → (batch, feat_dim) 등으로 차원 변환
    • LayerNorm, Dropout 등으로 일반화 및 안정성 향상
    • 다양한 입력 형태(1D, 2D, 3D 텐서)에 대한 견고한 처리

3. NearestMeanClassifier (NMC)

  • 역할: 클래스별 평균 특징 벡터를 저장하고, 입력 특징과의 거리(유사도)로 분류를 수행하는 모듈
  • 기능:
    • 각 클래스별 중심(평균) 벡터 계산 및 업데이트
    • 입력 특징과의 코사인 유사도/유클리드 거리 기반 분류
    • 지속학습 시 새로운 클래스 추가 및 이전 클래스 유지 기능

4. FisherInformationCalculator

  • 역할: EWC(Elastic Weight Consolidation) 등에서 파라미터별 중요도를 계산하기 위한 Fisher Information Matrix 추정
  • 기능:
    • 현재 태스크 데이터로 파라미터별 Fisher 값을 계산
    • EWC 정규화 손실 계산에 활용

5. EWC (Elastic Weight Consolidation)

  • 역할: Catastrophic Forgetting(망각) 방지를 위해 이전 태스크에서 중요했던 파라미터가 급격히 변하지 않도록 제약을 거는 정규화 모듈
  • 기능:
    • 이전 태스크의 파라미터와 현재 파라미터의 차이에 Fisher 값을 곱해 페널티 부여
    • λ(람다) 하이퍼파라미터로 제약 강도 조절

6. MER (Memory Aware Synapses / Experience Replay)

  • 역할: 리플레이 버퍼에 과거 샘플을 저장하고, 새로운 태스크 학습 시 과거 샘플도 일부 재사용해 망각을 완화
  • 기능:
    • 버퍼 크기, 샘플링 전략(랜덤, 중요도 기반 등) 관리
    • 현재 태스크와 과거 태스크 데이터의 균형적 학습

7. SupervisedContrastiveLoss

  • 역할: 지도 대조학습 손실로, 같은 클래스는 가까이, 다른 클래스는 멀리 임베딩되도록 학습
  • 기능:
    • 온도 하이퍼파라미터로 유사도 분포 제어
    • NCE(Noise Contrastive Estimation) 등과 유사한 방식

8. SpotInstanceHandler

  • 역할: 클라우드 환경에서 AWS Spot 인스턴스 중단 신호 감지, 안전한 체크포인트 저장 및 복원
  • 기능:
    • 중단 감지 시 현재 상태를 즉시 체크포인트로 저장
    • 재시작 시 마지막 체크포인트에서 복원

9. DynamicBatchSizeManager

  • 역할: GPU 사용률을 모니터링하며, 배치 크기를 동적으로 조정해 자원 활용률 극대화 및 OOM 방지
  • 기능:
    • 목표 GPU 사용률에 맞춰 배치 크기 증감
    • OOM 발생 시 자동으로 배치 크기 축소

10. DataLoader/Preprocessor

  • 역할: 데이터셋 로딩, 전처리, 토크나이징, FIM(중간채움) 등 다양한 입력 형태 지원
  • 기능:
    • 훈련/검증/테스트 데이터 분리
    • FIM, Prefix, Suffix 등 다양한 코드 자동완성 포맷 지원

11. Tokenizer

  • 역할: 입력 텍스트/코드를 토큰 단위로 변환, 모델에 맞는 인덱스 시퀀스 생성
  • 기능:
    • 패딩, 트렁케이션, special token 처리
    • 다양한 언어 및 코드 포맷 지원

12. Checkpointer

  • 역할: 학습 중간중간 모델 가중치, 옵티마이저 상태 등을 저장/복원
  • 기능:
    • symlink 등으로 최신 체크포인트 관리
    • atomic 저장, 파일 무결성 검증

13. 학습/평가 루프 (train_epoch, eval_epoch 등)

  • 역할: 에폭 단위의 학습/평가 진행, 손실/정확도/코드 생성 품질 등 다양한 지표 기록

정리:

이 구조는 LLM의 지속학습 실험에 필요한 모든 핵심 기능(특징 추출, 망각 방지, 리플레이, 동적 배치, 체크포인트, 다양한 코드 자동완성 포맷 등)을 체계적으로 지원하도록 설계된 것이 특징이다.

각 클래스의 책임이 명확히 분리되어 있어, 실험 목적에 따라 독립적으로 교체·확장하기 용이하다.

현재까지의 개선사항

1. PEFT 어댑터 로딩 문제 해결

  • 두 형식 모두 지원하도록 로직 개선
  • 예외 처리 및 로깅 강화

2. 학습률 설정 문제 해결

  • 설정 파일에서 올바른 키로 학습률 로드
  • 문자열에서 float으로 자동 변환

3. 학습 효율 개선

  • 배치 크기 최적화 (1 → 4)
  • 로깅 빈도 최적화로 오버헤드 감소

최신 코드

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CapaBoost + Continual Learning 통합 시스템
최신 연구 결과 기반 고성능 지속 학습 시스템
ICLR 2024 CapaBoost + MER + EWC + NCM + SCR

Created: June 2025
Authors: [James Kim]

실행 방법:
    python cc_train.py --config config.yaml
    또는
    ./run_training.sh
"""

import os
import sys
import json
import random
import logging
import argparse
import numpy as np
import time
import signal
import psutil
import GPUtil
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union, Any
from collections import defaultdict, deque
import datetime
from pathlib import Path
import warnings
from tqdm import tqdm
import yaml

# datasets 관련 import는 logger 초기화 후에 수행

# Transformers 관련 import 추가
# Transformers 관련 import 추가 (logger 초기화 후에 사용하기 위해 아래로 이동)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, Dataset, Subset
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

# 데이터셋 관련 import
from datasets import load_dataset

# Transformers 관련 import 추가
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    try:
        from tensorboardX import SummaryWriter
    except ImportError:
        # TensorBoard를 사용할 수 없을 때 더미 클래스 제공
        class SummaryWriter:
            def __init__(self, *args, **kwargs):
                logger.warning("TensorBoard와 TensorboardX 둘 다 설치되지 않았습니다. 로깅이 비활성화됩니다.")
            
            def add_scalar(self, *args, **kwargs):
                pass
                
            def add_scalars(self, *args, **kwargs):
                pass
                
            def add_figure(self, *args, **kwargs):
                pass
                
            def close(self, *args, **kwargs):
                pass

# 중요도에 따라 경고 필터링
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, message="torch.cuda.*_rng_state")
warnings.filterwarnings("ignore", category=UserWarning, message="The parameter 'pretrained' is deprecated")

# 로깅 설정
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('cc_train.log')
    ]
)
logger = logging.getLogger(__name__)

# Transformers 관련 import 추가
try:
    import transformers
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, PeftModel
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    logger.warning("transformers 또는 PEFT 라이브러리가 설치되지 않았습니다. 실제 모델 로드가 불가능합니다.")

# Datasets 관련 import 추가
try:
    from datasets import load_dataset
    DATASETS_AVAILABLE = True
except ImportError:
    DATASETS_AVAILABLE = False
    logger.warning("datasets 라이브러리가 설치되지 않았습니다. 데이터셋 로드가 불가능합니다.")

# 전역 시드 고정 함수
def set_seed(seed: int):
    """완전한 재현 가능성을 보장하기 위한 시드 설정"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # 멀티 GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    logger.info(f"모든 랜덤 시드가 {seed}로 고정되었습니다.")

# AWS Spot Instance 중단 감지 클래스
class SpotInterruptionHandler:
    """AWS Spot Instance 중단 감지 및 대응 핸들러"""

    def __init__(self, checkpoint_dir: str = "checkpoints", checkpoint_freq: int = 100):
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_freq = checkpoint_freq
        self.interrupted = False
        self.last_checkpoint_time = time.time()
        self.setup_interruption_detector()
        logger.info("Spot Instance 중단 감지기가 초기화되었습니다.")

    def setup_interruption_detector(self):
        """중단 신호 감지 설정"""
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # SIGTERM 시그널 처리기 등록 (AWS Spot 중단 신호)
        signal.signal(signal.SIGTERM, self.interruption_handler)

        # 수동 체크 메서드를 위한 추가 설정
        self.metadata_url = "http://169.254.169.254/latest/meta-data/spot/instance-action"

    def interruption_handler(self, signum, frame):
        """중단 시그널 처리"""
        self.interrupted = True
        logger.warning("Spot Instance 중단 신호가 감지되었습니다!")
        # 여기서는 플래그만 설정하고, 실제 체크포인트 저장은 메인 학습 루프에서 처리

    def check_interruption(self):
        """EC2 메타데이터를 통한 중단 예정 확인"""
        try:
            import requests
            response = requests.get(self.metadata_url, timeout=0.1)
            if response.status_code == 200:
                self.interrupted = True
                logger.warning(f"Spot 중단 임박! 메타데이터: {response.json()}")
                return True
        except Exception:
            # 중단 예정이 아니거나 메타데이터 접근 실패 시 무시
            pass
        return self.interrupted

    def should_checkpoint(self, step: int) -> bool:
        """체크포인트 저장 필요 여부 확인"""
        # 중단 감지 시 즉시 체크포인트
        if self.check_interruption():
            return True

        # 정기적인 체크포인트
        if step % self.checkpoint_freq == 0:
            return True

        # 시간 기반 체크포인트 (10분마다)
        current_time = time.time()
        if current_time - self.last_checkpoint_time > 600:  # 10분
            self.last_checkpoint_time = current_time
            return True

        return False

    def start_monitoring(self, model=None, optimizer=None, scheduler=None, scaler=None, trainer=None):
        """스팟 인스턴스 중단 모니터링 시작 - test_environment.py와의 호환성을 위한 덧방화 함수"""
        # test_environment.py에서 스팟 인스턴스 모니터링을 처리하민로 필요한 파라미터만 저장
        if model is not None:
            self.model = model
        if optimizer is not None:
            self.optimizer = optimizer
        if scheduler is not None:
            self.scheduler = scheduler
        if scaler is not None:
            self.scaler = scaler
        if trainer is not None:
            self.trainer = trainer
        
        logger.info("스팟 인스턴스 체크포인트 상태 초기화 완료 (test_environment.py에서 모니터링 수행)")
        
    def stop_monitoring(self):
        """스팟 인스턴스 중단 모니터링 종료 - test_environment.py와의 호환성을 위한 덧방화 함수"""
        logger.info("스팟 인스턴스 모니터링 종료 요청됨 (test_environment.py 처리)")
    
    def save_checkpoint(self, model, optimizer, scheduler, scaler, trainer, step: int):
        """체크포인트 저장"""
        checkpoint_path = os.path.join(self.checkpoint_dir, f"checkpoint-{step}.pt")
        state_dict = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict() if scheduler else None,
            'scaler': scaler.state_dict() if scaler else None,
            'step': step,
            'trainer_state': trainer.get_state_dict()
        }
        torch.save(state_dict, checkpoint_path)

        # 최신 체크포인트 심볼릭 링크 업데이트
        latest_path = os.path.join(self.checkpoint_dir, "checkpoint-latest.pt")
        # 심볼릭 링크 확인 및 안전하게 제거
        if os.path.lexists(latest_path):  # lexists는 심볼릭 링크 자체가 있는지 확인
            try:
                os.unlink(latest_path)  # 심볼릭 링크를 안전하게 제거
            except Exception as e:
                logger.warning(f"심볼릭 링크 제거 중 오류: {e}")
                
        # 심볼릭 링크 생성 시 예외 처리
        try:
            os.symlink(checkpoint_path, latest_path)
        except FileExistsError:
            # 여전히 존재한다면 강제로 다시 시도
            logger.warning("심볼릭 링크가 여전히 존재함. 강제 삭제 후 재생성 시도...")
            try:
                os.remove(latest_path)  # 강제 삭제
                os.symlink(checkpoint_path, latest_path)  # 다시 생성
            except Exception as e:
                logger.error(f"심볼릭 링크 강제 재생성 실패: {e}")
                logger.info(f"심볼릭 링크 생성은 실패했지만, 체크포인트 {checkpoint_path}는 정상 저장됨")

        logger.info(f"체크포인트가 저장되었습니다: {checkpoint_path}")
        self.last_checkpoint_time = time.time()

    def load_checkpoint(self, checkpoint_path, model, optimizer, scheduler, scaler, trainer):
        """지정된 경로의 체크포인트 로드"""
        if not os.path.exists(checkpoint_path):
            logger.warning(f"체크포인트 파일이 존재하지 않습니다: {checkpoint_path}")
            return 0

        logger.info(f"체크포인트 로드 중: {checkpoint_path}")
        try:
            # 기본 체크포인트 형식(.pt 파일) 로드 시도
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            
            # 모델 로드
            if 'model' in checkpoint:
                model.load_state_dict(checkpoint['model'])
                
            # 옵티마이저 로드
            if 'optimizer' in checkpoint and optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer'])
                
            # 스케줄러 로드
            if 'scheduler' in checkpoint and scheduler is not None and checkpoint['scheduler'] is not None:
                scheduler.load_state_dict(checkpoint['scheduler'])
                
            # 스케일러 로드
            if 'scaler' in checkpoint and scaler is not None and checkpoint['scaler'] is not None:
                scaler.load_state_dict(checkpoint['scaler'])
                
            # 트레이너 상태 로드
            if 'trainer_state' in checkpoint and trainer is not None:
                trainer.load_state_dict(checkpoint['trainer_state'])
            
            step = checkpoint.get('step', 0)
            logger.info(f"체크포인트 로드 완료. 스텝 {step}부터 재개합니다.")
            return step
            
        except Exception as e:
            # 디렉토리인 경우 (최종 모델 디렉토리 로드 시도)
            if os.path.isdir(checkpoint_path):
                logger.info(f"체크포인트 경로가 디렉토리입니다: {checkpoint_path}")
                try:
                    # 허깅페이스 모델 디렉토리에서 로드 시도
                    # PEFT 모델이 이미 적용된 경우 적절한 방식으로 처리
                    if hasattr(model, 'pretrained_model') or hasattr(model, 'base_model'):
                        # 이미 PEFT 모델인 경우
                        from peft import PeftModel, PeftConfig
                        logger.info(f"PEFT 모델에 대한 어댑터 로드 시도: {checkpoint_path}")
                        
                        # 기존 모델에서 기본 모델 추출
                        if hasattr(model, 'pretrained_model'):
                            base_model = model.pretrained_model
                        else:
                            base_model = model.base_model if hasattr(model, 'base_model') else model
                        
                        # 1. 기존 어댑터가 있는 경우 완전히 제거 (PEFT 0.4.0 방식)
                        if hasattr(model, 'unload'):
                            try:
                                model.unload()
                                logger.info("기존 PEFT 어댑터 제거 완료")
                            except Exception as e:
                                logger.warning(f"PEFT 어댑터 제거 중 오류(무시함): {e}")
                        
                        # 2. 어댑터 구성 파일 확인
                        adapter_config_path = os.path.join(checkpoint_path, "adapter_config.json")
                        if not os.path.exists(adapter_config_path):
                            logger.warning(f"adapter_config.json 파일이 없어 일반 모델로 로드합니다: {checkpoint_path}")
                            model = base_model
                            model.load_state_dict(torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"), map_location="cpu"), strict=False)
                            return model, optimizer, scheduler, scaler, trainer
                            
                        # 3. 어댑터 구성 정보 로드
                        try:
                            # PEFT 호환성 문제 해결을 위한 수정
                            # _prepare_model_for_peft_adapter 함수를 사용하지 않고 직접 어댑터 적용
                            
                            # 어댑터 구성 파일 로드
                            peft_config = PeftConfig.from_pretrained(checkpoint_path)
                            logger.info(f"어댑터 구성 로드 완료: {peft_config.peft_type}, target_modules={getattr(peft_config, 'target_modules', 'N/A')}")
                            
                            # 4. 기본 모델을 PEFT 어댑터에 맞게 준비 (가장 중요한 단계)
                            logger.info("기본 모델을 PEFT 어댑터에 맞게 준비 중...")
                            # PEFT 0.4.0에서는 이 함수가 없으므로 PeftModel이 자동으로 처리하도록 함
                            
                            # 5. 체크포인트에서 가중치만 로드 (두 형식 지원: .bin 및 .safetensors)
                            bin_weights_path = os.path.join(checkpoint_path, "adapter_model.bin")
                            safetensors_weights_path = os.path.join(checkpoint_path, "adapter_model.safetensors")
                            
                            # 6. 새 PeftModel 생성 (PEFT 0.4.0 방식)
                            model = PeftModel(base_model, peft_config, adapter_name="default")
                            
                            # 7. 가중치 파일 여부 확인 및 로드
                            weights_loaded = False
                            
                            # .bin 파일 확인
                            if os.path.exists(bin_weights_path):
                                logger.info(f"어댓터 가중치 (.bin) 로드 중: {bin_weights_path}")
                                adapter_weights = torch.load(bin_weights_path, map_location="cpu")
                                model.load_state_dict(adapter_weights, strict=False)
                                logger.info("어댓터 가중치 (.bin) 로드 완료")
                                weights_loaded = True
                            
                            # .safetensors 파일 확인
                            elif os.path.exists(safetensors_weights_path):
                                try:
                                    from safetensors import safe_open
                                    from safetensors.torch import load_file
                                    
                                    logger.info(f"어댓터 가중치 (.safetensors) 로드 중: {safetensors_weights_path}")
                                    # safetensors 파일에서 가중치 로드
                                    adapter_weights = load_file(safetensors_weights_path)
                                    model.load_state_dict(adapter_weights, strict=False)
                                    logger.info("어댓터 가중치 (.safetensors) 로드 완료")
                                    weights_loaded = True
                                except Exception as e:
                                    logger.error(f"safetensors 파일 로드 오류: {str(e)}")
                            
                            # 가중치 파일을 찾지 못한 경우
                            if not weights_loaded:
                                logger.warning(f"어댓터 가중치 파일을 찾을 수 없음: {bin_weights_path} 또는 {safetensors_weights_path}")
                                logger.warning("새 어댓터로 초기화하여 계속합니다.")
                            
                            # 8. 어댓터 활성화 확인 및 설정
                            if hasattr(model, 'active_adapter') and not model.active_adapter:
                                model.set_adapter("default")
                                logger.info("어댓터 활성화됨: default")
                        except Exception as e:
                            logger.error(f"PEFT 어댑터 로드 중 오류 발생: {str(e)}")
                            # 오류 발생 시 기본 모델 사용
                            model = base_model
                    else:
                        # 일반 모델인 경우
                        model.from_pretrained(checkpoint_path)
                    
                    logger.info(f"모델이 로드되었습니다: {checkpoint_path}")
                    return 0  # 디렉토리에서 로드하면 스텝을 0으로 초기화
                except Exception as inner_e:
                    logger.error(f"모델 로드 중 오류 발생: {inner_e}")
            else:
                logger.error(f"체크포인트 로드 중 오류 발생: {e}")
            
            return 0
    
    def load_latest_checkpoint(self, model, optimizer, scheduler, scaler, trainer):
        """최신 체크포인트 로드"""
        latest_path = os.path.join(self.checkpoint_dir, "checkpoint-latest.pt")
        if not os.path.exists(latest_path):
            logger.info("로드할 체크포인트가 없습니다.")
            return 0  # 시작 스텝

        logger.info(f"체크포인트를 로드합니다: {latest_path}")
        checkpoint = torch.load(latest_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if scheduler and checkpoint['scheduler']:
            scheduler.load_state_dict(checkpoint['scheduler'])
        if scaler and checkpoint['scaler']:
            scaler.load_state_dict(checkpoint['scaler'])

        if 'trainer_state' in checkpoint:
            trainer.load_state_dict(checkpoint['trainer_state'])

        step = checkpoint['step']
        logger.info(f"체크포인트 로드 완료, 스텝 {step}부터 재개합니다.")
        return step

# CapaBoost 구현 클래스
class CapaBoostMaskGenerator:
    """
    CapaBoost 정적 랜덤 마스크 생성기
    ICLR 2024 논문의 완전한 구현

    참고 논문: "Increasing Model Capacity for Free: A Simple Strategy for Parameter Efficient Fine-tuning"
    """
    def __init__(self, sparsity: float = 0.6, num_masks: int = 4, seed: int = 42):
        """
        Args:
            sparsity: 마스크의 스파시티 (0.0-1.0, 0.6이 최적)
            num_masks: 생성할 마스크 수 (4가 최적)
            seed: 결정론적 마스크 생성을 위한 시드
        """
        self.sparsity = sparsity
        self.num_masks = num_masks
        self.seed = seed
        self.masks_cache = {}
        logger.info(f"CapaBoost 초기화: 스파시티={sparsity}, 마스크 수={num_masks}, 시드={seed}")

    def generate_static_masks(self, shape: Tuple[int, int], layer_name: str) -> List[torch.Tensor]:
        """
        정적이고 결정론적인 랜덤 마스크 생성
        논문의 Equation (3) 구현
        """
        cache_key = f"{layer_name}_{shape}_{self.sparsity}_{self.num_masks}"
        if cache_key in self.masks_cache:
            return self.masks_cache[cache_key]

        # 시드를 이용한 결정론적 마스크 생성
        masks = []
        for i in range(self.num_masks):
            torch.manual_seed(self.seed + hash(layer_name) + i)
            # 베르누이 분포를 이용한 스파스 마스크 생성
            mask = torch.bernoulli(torch.full(shape, 1 - self.sparsity)).bool()
            masks.append(mask)

        self.masks_cache[cache_key] = masks
        return masks

    def apply_capaboost(self, weight: torch.Tensor, layer_name: str) -> torch.Tensor:
        """
        CapaBoost 메커니즘 적용
        z = Σ(w ⊙ mi)x + b 구현
        """
        masks = self.generate_static_masks(weight.shape, layer_name)

        # 마스킹된 가중치들의 합
        masked_weights = []
        for mask in masks:
            masked_weight = weight * mask.float().to(weight.device)
            masked_weights.append(masked_weight)

        # 효과적인 랭크 증가를 위한 가중치 합산
        effective_weight = sum(masked_weights)
        return effective_weight

# CapaBoost 적용을 위한 Linear 래퍼 클래스
class CapaBoostLinear(nn.Module):
    """CapaBoost를 적용한 Linear 레이어"""

    def __init__(self, in_features: int, out_features: int, bias: bool = True, 
                 mask_generator: CapaBoostMaskGenerator = None, layer_name: str = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features, bias)
        self.mask_generator = mask_generator
        self.layer_name = layer_name or f"linear_{id(self)}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.mask_generator is not None:
            effective_weight = self.mask_generator.apply_capaboost(self.linear.weight, self.layer_name)
            return F.linear(x, effective_weight, self.linear.bias)
        else:
            return self.linear(x)

# Fisher Information Matrix 계산 클래스
class FisherInformationComputer:
    """
    정확한 Fisher Information Matrix 계산
    EWC와 LwF에서 사용
    """
    def __init__(self, mode: str = "empirical"):
        """
        Args:
            mode: Fisher 계산 모드 ('exact', 'empirical', 'diagonal')
        """
        self.mode = mode
        logger.info(f"Fisher Information 계산기 초기화: 모드={mode}")

    def compute_fisher_diagonal(self, model: nn.Module, dataloader, 
                              criterion, num_samples: int = 1000) -> Dict[str, torch.Tensor]:
        """
        대각 Fisher Information Matrix 계산
        F_i,i = E[(∂log p(y|x,θ)/∂θ_i)²] 구현
        """
        fisher = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                fisher[name] = torch.zeros_like(param)

        model.eval()
        samples_processed = 0

        with torch.no_grad():
            for data, target in tqdm(dataloader, desc="Fisher 계산 중"):
                if samples_processed >= num_samples:
                    break

                data, target = data.cuda(), target.cuda()
                batch_size = data.size(0)
                samples_processed += batch_size

                # 모드별 Fisher 계산
                if self.mode == "exact":
                    # 정확한 Fisher Information 계산
                    output = model(data)
                    log_probs = F.log_softmax(output, dim=1)

                    for class_idx in range(output.size(1)):
                        model.zero_grad()
                        class_log_prob = log_probs[:, class_idx].mean()
                        class_log_prob.backward(retain_graph=(class_idx < output.size(1) - 1))

                        for name, param in model.named_parameters():
                            if param.grad is not None:
                                fisher[name] += param.grad.pow(2) * batch_size / num_samples

                elif self.mode == "empirical":
                    # 경험적 Fisher Information (실제 레이블 사용)
                    for i in range(batch_size):
                        model.zero_grad()
                        output = model(data[i:i+1])
                        loss = criterion(output, target[i:i+1])
                        loss.backward()

                        for name, param in model.named_parameters():
                            if param.grad is not None:
                                fisher[name] += param.grad.pow(2) / num_samples

                elif self.mode == "diagonal":
                    # 대각 근사 Fisher (메모리 효율적)
                    output = model(data)
                    log_probs = F.log_softmax(output, dim=1)

                    # 실제 타겟에 대한 로그 확률의 그래디언트
                    model.zero_grad()
                    target_log_probs = log_probs[torch.arange(batch_size), target]
                    target_log_probs.mean().backward()

                    for name, param in model.named_parameters():
                        if param.grad is not None:
                            fisher[name] += param.grad.pow(2) * batch_size / num_samples

        model.train()
        return fisher

# Meta-Experience Replay 구현
class MetaExperienceReplay:
    """
    Meta-Experience Replay 완전 구현
    Stanford NLP의 원본 논문 기반
    """
    def __init__(self, buffer_size: int = 5120, beta: float = 0.1, 
                 gamma: float = 0.1, s: int = 10):
        """
        Args:
            buffer_size: 메모리 버퍼 크기
            beta: 배치 내 업데이트 강도
            gamma: 배치 간 업데이트 강도
            s: Reptile 스텝 수
        """
        self.buffer_size = buffer_size
        self.buffer = []
        self.beta = beta
        self.gamma = gamma
        self.s = s
        self.class_counter = defaultdict(int)
        logger.info(f"MER 초기화: 버퍼 크기={buffer_size}, β={beta}, γ={gamma}, s={s}")

    def reservoir_sampling(self, x: torch.Tensor, y: torch.Tensor):
        """저장소 샘플링을 이용한 메모리 버퍼 관리"""
        for i in range(x.size(0)):
            sample = (x[i].clone(), y[i].clone())

            if len(self.buffer) < self.buffer_size:
                self.buffer.append(sample)
                self.class_counter[y[i].item()] += 1
            else:
                # 클래스 밸런스를 유지하기 위한 선택적 대체
                if self.should_replace(y[i].item()):
                    # 같은 클래스의 샘플 대체
                    indices = [j for j, (_, label) in enumerate(self.buffer) if label == y[i]]
                    if indices:
                        idx = random.choice(indices)
                        self.buffer[idx] = sample

    def should_replace(self, class_label: int) -> bool:
        """클래스 밸런스를 위한 대체 결정"""
        # 적게 등장한 클래스는 대체 확률 높임
        counts = list(self.class_counter.values())
        if not counts:
            return True

        avg_count = sum(counts) / len(counts)
        current_count = self.class_counter[class_label]

        # 평균보다 적게 등장한 클래스는 대체할 확률 높임
        if current_count < avg_count:
            return random.random() < 0.8
        else:
            return random.random() < 0.2

    def sample_batch(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """버퍼에서 배치 샘플링"""
        if len(self.buffer) == 0:
            return None, None

        indices = np.random.choice(len(self.buffer), 
                                 min(batch_size, len(self.buffer)), 
                                 replace=False)

        x_batch = torch.stack([self.buffer[i][0] for i in indices])
        y_batch = torch.stack([self.buffer[i][1] for i in indices])

        return x_batch, y_batch

    def meta_update(self, model: nn.Module, optimizer, criterion, 
                   current_batch: Tuple[torch.Tensor, torch.Tensor],
                   memory_batch: Tuple[torch.Tensor, torch.Tensor]) -> float:
        """
        MER 메타 업데이트 수행
        Reptile 알고리즘 기반
        """
        x_curr, y_curr = current_batch
        x_mem, y_mem = memory_batch

        # 현재 파라미터 저장
        old_params = {name: param.clone() for name, param in model.named_parameters()}

        # Within-batch Reptile update
        for _ in range(self.s):
            # 현재 배치 업데이트
            optimizer.zero_grad()
            output_curr = model(x_curr)
            loss_curr = criterion(output_curr, y_curr)
            loss_curr.backward()
            optimizer.step()

            if x_mem is not None:
                # 메모리 배치 업데이트  
                inner_params = {name: param.clone() for name, param in model.named_parameters()}

                optimizer.zero_grad()
                output_mem = model(x_mem)
                loss_mem = criterion(output_mem, y_mem)
                loss_mem.backward()
                optimizer.step()

                # Within-batch interpolation
                for name, param in model.named_parameters():
                    param.data = inner_params[name] + self.beta * (param.data - inner_params[name])

        # Across-batch Reptile update (메타 업데이트)
        if x_mem is not None:
            for name, param in model.named_parameters():
                param.data = old_params[name] + self.gamma * (param.data - old_params[name])

        return loss_curr.item()

    def get_state_dict(self):
        """상태 저장을 위한 딕셔너리 반환"""
        return {
            'buffer': self.buffer,
            'class_counter': dict(self.class_counter),
            'buffer_size': self.buffer_size,
            'beta': self.beta,
            'gamma': self.gamma,
            's': self.s
        }

    def load_state_dict(self, state_dict):
        """저장된 상태 로드"""
        self.buffer = state_dict['buffer']
        self.class_counter = defaultdict(int, state_dict['class_counter'])
        self.buffer_size = state_dict['buffer_size']
        self.beta = state_dict['beta']
        self.gamma = state_dict['gamma']
        self.s = state_dict['s']

# 다이나믹 배치 사이즈 관리 클래스
class DynamicBatchSizeManager:
    """
    메모리 사용량에 따라 배치 크기를 동적으로 조정
    OOM 오류 방지 및 최적 성능 보장
    """
    def __init__(self, initial_batch_size: int, 
                 target_gpu_util: float = 0.85,
                 max_batch_size: int = 128,
                 min_batch_size: int = 1):
        self.batch_size = initial_batch_size
        self.target_gpu_util = target_gpu_util
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size
        self.adjustment_counter = 0
        self.history = deque(maxlen=10)  # 최근 조정 내역
        logger.info(f"동적 배치 사이즈 관리자 초기화: 시작={initial_batch_size}, 목표 GPU 사용률={target_gpu_util}")

    def update(self, oom_occurred: bool = False) -> int:
        """
        현재 GPU 메모리 상태에 따라 배치 크기 업데이트

        Args:
            oom_occurred: OOM 오류 발생 여부

        Returns:
            새로운 배치 크기
        """
        if oom_occurred:
            # OOM 발생 시 배치 크기 즉시 감소
            new_batch_size = max(self.min_batch_size, int(self.batch_size * 0.5))
            self.history.append(("OOM", self.batch_size, new_batch_size))
            self.batch_size = new_batch_size
            logger.warning(f"OOM 감지, 배치 크기 감소: {self.batch_size}")
            return self.batch_size

        # 5번마다 배치 크기 조정 (너무 자주 조정하지 않도록)
        self.adjustment_counter += 1
        if self.adjustment_counter % 5 != 0:
            return self.batch_size

        # GPU 메모리 사용량 확인
        try:
            gpus = GPUtil.getGPUs()
            if not gpus:
                return self.batch_size

            gpu = gpus[0]  # 첫 번째 GPU 사용 가정
            mem_util = gpu.memoryUtil

            # 메모리 사용량에 따른 배치 크기 조정
            if mem_util > self.target_gpu_util + 0.05:
                # 메모리 사용량이 너무 높음, 배치 크기 감소
                new_batch_size = max(self.min_batch_size, int(self.batch_size * 0.9))
                action = "감소"
            elif mem_util < self.target_gpu_util - 0.1 and self.batch_size < self.max_batch_size:
                # 메모리 사용량이 낮음, 배치 크기 증가
                new_batch_size = min(self.max_batch_size, int(self.batch_size * 1.1))
                action = "증가"
            else:
                # 적정 범위 내, 유지
                return self.batch_size

            if new_batch_size != self.batch_size:
                self.history.append((action, self.batch_size, new_batch_size))
                self.batch_size = new_batch_size
                logger.info(f"배치 크기 {action}: {self.batch_size} (GPU 메모리 사용률: {mem_util:.2f})")

        except Exception as e:
            logger.warning(f"배치 크기 업데이트 중 오류 발생: {e}")

        return self.batch_size

# 적응적 학습률 스케줄러
class AdaptiveLRScheduler:
    """
    선형 감소 + 웜업 + 종료 시 급격한 감소를 포함한 
    적응적 학습률 스케줄링
    """
    def __init__(self, optimizer, total_steps: int, 
                 warmup_steps: int = 0,
                 min_lr_ratio: float = 0.1,
                 end_lr_ratio: float = 0.01):
        """
        Args:
            optimizer: 학습률을 조정할 옵티마이저
            total_steps: 총 훈련 스텝 수
            warmup_steps: 웜업 스텝 수
            min_lr_ratio: 최소 학습률 비율
            end_lr_ratio: 종료 시 학습률 비율
        """
        self.optimizer = optimizer
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.min_lr_ratio = min_lr_ratio
        self.end_lr_ratio = end_lr_ratio
        self.initial_lr = [group['lr'] for group in optimizer.param_groups]

        self.scheduler = LambdaLR(optimizer, self.lr_lambda)
        logger.info(f"적응적 학습률 스케줄러 초기화: 웜업={warmup_steps}, 총 스텝={total_steps}")

    def lr_lambda(self, current_step: int):
        """학습률 계산 함수"""
        if current_step < self.warmup_steps:
            # 웜업 구간: 0 -> 1
            return float(current_step) / float(max(1, self.warmup_steps))

        # 훈련 막바지 10%에서는 더 급격히 감소
        if current_step > 0.9 * self.total_steps:
            progress = (current_step - 0.9 * self.total_steps) / (0.1 * self.total_steps)
            decay = max(self.end_lr_ratio, (1.0 - progress) * (1.0 - 0.9) + 0.9 * self.end_lr_ratio)
            return decay

        # 메인 구간: 선형 감소 (1 -> min_lr_ratio)
        progress = (current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        decay = max(self.min_lr_ratio, 1.0 - progress)
        return decay

    def step(self):
        self.scheduler.step()

    def get_last_lr(self):
        return self.scheduler.get_last_lr()

    def state_dict(self):
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.scheduler.load_state_dict(state_dict)

# Feature Adapter 구현
class FeatureAdapter(nn.Module):
    """
    언어 모델의 큰 출력 차원을 분류기가 기대하는 작은 차원으로 변환하는 어댑터
    행렬곱 차원 불일치 문제를 해결하기 위한 중간 변환 레이어
    """
    
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 2048):
        """
        Args:
            input_dim: 입력 차원 (언어 모델의 vocab_size 또는 출력 차원)
            output_dim: 출력 차원 (분류기가 기대하는 특징 차원)  
            hidden_dim: 중간 은닉층 차원
        """
        super().__init__()
        
        # 입력 차원이 매우 큰 경우 (vocab_size 등) 단계적 차원 축소
        if input_dim > 10000:
            self.layers = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim // 2, output_dim)
            )
        else:
            # 입력 차원이 작은 경우 단순한 변환
            self.layers = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, output_dim)
            )
            
        logger.info(f"FeatureAdapter 초기화: {input_dim}{output_dim} (hidden: {hidden_dim})")
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        특징 변환 수행
        
        Args:
            x: 입력 텐서 - 다양한 형태의 모델 출력을 처리
        
        Returns:
            변환된 특징 텐서 (batch_size, output_dim)
        """
        # 입력 텐서 형태 정규화
        if len(x.shape) == 3:  # (batch_size, seq_len, hidden_dim)
            # 시퀀스의 마지막 토큰 사용 (언어 모델의 일반적 패턴)
            x = x[:, -1, :]
        elif len(x.shape) > 3:
            # 더 고차원인 경우 배치 차원을 제외한 모든 차원을 평탄화
            batch_size = x.size(0)
            x = x.view(batch_size, -1)
        elif len(x.shape) == 1:
            # 1차원인 경우 배치 차원 추가
            x = x.unsqueeze(0)
        
        # 차원 변환 적용
        return self.layers(x)

# Nearest Mean Classifier 구현
class NearestMeanClassifier(nn.Module):
    """
    Stability Gap 완화를 위한 Nearest Mean Classifier
    Supervised Contrastive Replay와 결합 시 최고 성능
    """
    def __init__(self, num_classes: int, feat_dim: int):
        """
        Args:
            num_classes: 클래스 수
            feat_dim: 특징 벡터 차원
        """
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.register_buffer('class_means', torch.zeros(num_classes, feat_dim))
        self.register_buffer('class_counts', torch.zeros(num_classes))
        logger.info(f"Nearest Mean Classifier 초기화: 클래스={num_classes}, 차원={feat_dim}")

    def update_class_means(self, features: torch.Tensor, labels: torch.Tensor):
        """클래스별 평균 특징 벡터 업데이트"""
        for c in range(self.num_classes):
            idx = (labels == c)
            if idx.sum() > 0:
                class_feats = features[idx]
                self.class_means[c] = (self.class_means[c] * self.class_counts[c] + class_feats.sum(0)) / (self.class_counts[c] + idx.sum())
                self.class_counts[c] += idx.sum()

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        가장 가까운 클래스 평균으로 분류

        Args:
            features: 모델에서 추출한 특징 벡터

        Returns:
            각 클래스의 로짓 값
        """
        # L2 정규화
        norm_features = F.normalize(features, p=2, dim=1)
        norm_class_means = F.normalize(self.class_means, p=2, dim=1)

        # 코사인 유사도 계산
        logits = torch.matmul(norm_features, norm_class_means.t())

        # 온도 스케일링 (선택적)
        # logits = logits / 0.07  # temperature

        return logits

    def get_state_dict(self):
        """상태 저장"""
        return {
            'class_means': self.class_means.clone(),
            'class_counts': self.class_counts.clone(),
            'num_classes': self.num_classes,
            'feat_dim': self.feat_dim
        }

    def load_state_dict(self, state_dict, strict=True):
        """상태 로드"""
        self.num_classes = state_dict['num_classes']
        self.feat_dim = state_dict['feat_dim']
        self.class_means = state_dict['class_means'].clone()
        self.class_counts = state_dict['class_counts'].clone()

# Supervised Contrastive Replay 손실
class SupervisedContrastiveLoss(nn.Module):
    """
    NCM과 함께 사용하는 지도 대조 손실
    같은 클래스의 샘플들은 가까게, 다른 클래스의 샘플들은 멀게 배치
    """
    def __init__(self, temperature: float = 0.07):
        """
        Args:
            temperature: 대조 손실 온도 파라미터
        """
        super().__init__()
        self.temperature = temperature
        logger.info(f"지도 대조 손실 초기화: 온도={temperature}")

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        지도 대조 손실 계산

        Args:
            features: 정규화된 특징 벡터 (batch_size, feat_dim)
            labels: 클래스 레이블 (batch_size)

        Returns:
            대조 손실 값
        """
        # L2 정규화
        features = F.normalize(features, p=2, dim=1)

        # 유사도 행렬 계산
        similarity_matrix = torch.matmul(features, features.t()) / self.temperature

        # 마스크 생성: 같은 클래스 = 1, 다른 클래스 = 0
        batch_size = features.size(0)
        labels_matrix = labels.expand(batch_size, batch_size).eq(labels.expand(batch_size, batch_size).t())

        # 자기 자신은 제외
        labels_matrix.fill_diagonal_(False)

        # 양의 쌍이 있는지 확인
        pos_pairs = labels_matrix.sum(dim=1) > 0

        if not pos_pairs.any():
            return torch.tensor(0.0, device=features.device)

        # 마스크가 있는 샘플만 선택
        labels_matrix = labels_matrix[pos_pairs]
        similarity_matrix = similarity_matrix[pos_pairs]

        # 로그-소프트맥스와 마스크 적용
        logits_max, _ = torch.max(similarity_matrix, dim=1, keepdim=True)
        logits = similarity_matrix - logits_max.detach()  # 수치 안정성

        exp_logits = torch.exp(logits)
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # 양의 쌍의 로그 확률 평균
        mean_log_prob_pos = (labels_matrix * log_prob).sum(1) / labels_matrix.sum(1)

        # 손실 계산
        loss = -mean_log_prob_pos.mean()
        return loss

# EWC (Elastic Weight Consolidation) 구현
class ElasticWeightConsolidation:
    """
    Elastic Weight Consolidation
    이전 태스크의 중요한 가중치를 보호하여 Catastrophic Forgetting 완화
    """
    def __init__(self, model: nn.Module, fisher_computer: FisherInformationComputer,
                 lambda_ewc: float = 5000.0, normalize_fisher: bool = True):
        """
        Args:
            model: 모델
            fisher_computer: Fisher Information 계산기
            lambda_ewc: EWC 강도 계수
            normalize_fisher: Fisher 정규화 여부
        """
        self.model = model
        self.fisher_computer = fisher_computer
        self.lambda_ewc = lambda_ewc
        self.normalize_fisher = normalize_fisher

        self.fisher_dict = {}  # task_id -> fisher
        self.optpar_dict = {}  # task_id -> optimal parameters

        logger.info(f"EWC 초기화: λ={lambda_ewc}, 정규화={normalize_fisher}")

    def register_task(self, task_id: int, dataloader, criterion):
        """
        새 태스크 등록 및 Fisher Information 계산

        Args:
            task_id: 태스크 ID
            dataloader: 태스크 데이터 로더
            criterion: 손실 함수
        """
        logger.info(f"태스크 {task_id} Fisher Information 계산 중...")
        # Fisher Information 계산
        fisher = self.fisher_computer.compute_fisher_diagonal(
            self.model, dataloader, criterion
        )

        # Fisher 정규화 (선택적)
        if self.normalize_fisher:
            for name, f in fisher.items():
                fisher[name] = f / f.sum()

        # 현재 파라미터 저장
        optpar = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                optpar[name] = param.data.clone()

        # 태스크 정보 저장
        self.fisher_dict[task_id] = fisher
        self.optpar_dict[task_id] = optpar

        logger.info(f"태스크 {task_id} 등록 완료")

    def compute_ewc_loss(self) -> torch.Tensor:
        """
        EWC 정규화 손실 계산
        이전 태스크의 중요 파라미터 변화에 페널티 부여

        Returns:
            EWC 손실
        """
        loss = torch.tensor(0., device=next(self.model.parameters()).device)

        # 등록된 태스크가 없으면 손실 없음
        if not self.fisher_dict:
            return loss

        # 각 태스크에 대한 EWC 손실 계산
        for task_id, fisher in self.fisher_dict.items():
            optpar = self.optpar_dict[task_id]

            for name, param in self.model.named_parameters():
                if name in fisher and param.requires_grad:
                    try:
                        # 텐서 처리를 안전하게 수행
                        # 가능한 경우 먼저 머신 에러 체크
                        if fisher[name].shape != param.shape or optpar[name].shape != param.shape:
                            logger.warning(f"EWC 손실 계산 시 텐서 형태 불일치 무시: {name}, {fisher[name].shape}, {param.shape}, {optpar[name].shape}")
                            continue
                            
                        # 두 텐서의 차이
                        diff = param - optpar[name]
                        # 제곱 계산
                        squared_diff = diff.pow(2)
                        # Fisher로 가중치 계산
                        weighted_squared_diff = fisher[name] * squared_diff
                        # 합계 계산
                        param_loss = weighted_squared_diff.sum()
                        # 확인 후 추가
                        if torch.isfinite(param_loss).all():  # NaN 확인
                            loss += param_loss
                        else:
                            logger.warning(f"EWC 손실에서 NaN/Inf 발견, 무시함: {name}")
                    except Exception as e:
                        logger.warning(f"EWC 손실 계산 중 오류 발생(무시): {name}, {str(e)}")
                        continue

        # 태스크 수로 정규화
        num_tasks = len(self.fisher_dict)
        if num_tasks > 0:
            loss = self.lambda_ewc * loss / num_tasks

        return loss

    def get_state_dict(self):
        """상태 저장"""
        return {
            'fisher_dict': self.fisher_dict,
            'optpar_dict': self.optpar_dict,
            'lambda_ewc': self.lambda_ewc,
            'normalize_fisher': self.normalize_fisher
        }

    def load_state_dict(self, state_dict):
        """상태 로드"""
        self.fisher_dict = state_dict['fisher_dict']
        self.optpar_dict = state_dict['optpar_dict']
        self.lambda_ewc = state_dict['lambda_ewc']
        self.normalize_fisher = state_dict['normalize_fisher']

# 지속 학습 메트릭 계산기
class ContinualLearningMetrics:
    """
    지속 학습 전용 평가 메트릭
    ACC, BWT, FWT, Stability Gap 등 포괄적 평가
    """
    def __init__(self):
        self.task_accuracies = defaultdict(list)
        self.stability_gaps = []
        self.task_names = []

    def update(self, task_id: int, test_results: Dict[int, float]):
        """
        태스크별 정확도 업데이트

        Args:
            task_id: 현재 학습 중인 태스크 ID
            test_results: 모든 이전 태스크에 대한 테스트 결과 {task_id: accuracy}
        """
        for test_task_id, accuracy in test_results.items():
            self.task_accuracies[test_task_id].append(accuracy)

    def record_stability_gap(self, task_id: int, accuracies: List[float]):
        """
        Stability Gap 기록

        Args:
            task_id: 태스크 ID
            accuracies: 학습 중 이전 태스크의 정확도 변화
        """
        if len(accuracies) < 3:  # 너무 적은 데이터
            return

        initial_acc = accuracies[0]
        min_acc = min(accuracies)
        final_acc = accuracies[-1]

        gap = initial_acc - min_acc
        recovery = final_acc - min_acc

        self.stability_gaps.append({
            'task_id': task_id,
            'gap': gap,
            'recovery': recovery,
            'recovery_ratio': recovery / gap if gap > 0 else 1.0
        })

    def compute_metrics(self) -> Dict[str, float]:
        """
        BWT, FWT, ACC, Stability Gap 계산

        Returns:
            계산된 메트릭들
        """
        num_tasks = len(self.task_accuracies)

        # Average Accuracy (ACC)
        final_accuracies = [self.task_accuracies[i][-1] for i in range(num_tasks)]
        acc = np.mean(final_accuracies)

        # Backward Transfer (BWT)
        bwt = 0
        for i in range(num_tasks - 1):
            bwt += self.task_accuracies[i][-1] - self.task_accuracies[i][i]
        bwt /= (num_tasks - 1) if num_tasks > 1 else 1

        # Forward Transfer (FWT) - 첫 번째 에포크 성능 기반
        fwt = 0
        baseline_acc = 0  # 무작위 성능 가정
        for i in range(1, num_tasks):
            if len(self.task_accuracies[i]) > 0:
                fwt += self.task_accuracies[i][0] - baseline_acc
        fwt /= (num_tasks - 1) if num_tasks > 1 else 1

        # Average Stability Gap
        avg_gap = 0
        avg_recovery_ratio = 0
        if self.stability_gaps:
            avg_gap = np.mean([gap['gap'] for gap in self.stability_gaps])
            avg_recovery_ratio = np.mean([gap['recovery_ratio'] for gap in self.stability_gaps])

        return {
            'ACC': acc,
            'BWT': bwt, 
            'FWT': fwt,
            'final_accuracies': final_accuracies,
            'stability_gap': avg_gap,
            'recovery_ratio': avg_recovery_ratio
        }

    def get_state_dict(self):
        """상태 저장"""
        return {
            'task_accuracies': dict(self.task_accuracies),
            'stability_gaps': self.stability_gaps,
            'task_names': self.task_names
        }

    def load_state_dict(self, state_dict):
        """저장된 상태 로드"""
        self.task_accuracies = defaultdict(list, state_dict['task_accuracies'])
        self.stability_gaps = state_dict['stability_gaps']
        self.task_names = state_dict['task_names']

# 메인 지속 학습 트레이너 클래스
class ContinualLearner:
    """
    CapaBoost + Continual Learning 통합 시스템
    높은 성공률과 안정성을 위한 모든 기법 통합
    """
    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config: 설정 파라미터
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # 시드 설정
        set_seed(config.get('seed', 42))

        # 모델 구성
        self.build_model()

        # 핵심 컴포넌트 초기화
        self.setup_components()

        # 메트릭 및 로깅 설정
        self.metrics = ContinualLearningMetrics()
        self.writer = SummaryWriter(config.get('log_dir', 'runs/continual_learning'))

        # AWS Spot Instance 중단 핸들러
        self.spot_handler = SpotInterruptionHandler(
            checkpoint_dir=config.get('checkpoint_dir', 'checkpoints'),
            checkpoint_freq=config.get('checkpoint_freq', 100)
        )

        # 배치 크기 관리자
        self.batch_manager = DynamicBatchSizeManager(
            initial_batch_size=config.get('batch_size', 32),
            target_gpu_util=config.get('target_gpu_util', 0.85),
            max_batch_size=config.get('max_batch_size', 128)
        )

        # 학습 상태
        self.current_task_id = 0
        self.global_step = 0
        self.best_accuracies = {}

        logger.info("ContinualLearner 초기화 완료")

    def build_model(self):
        """모델 구성 - DeepSeek-Coder 모델 로드"""
        model_config = self.config.get('model', {})
        model_name = model_config.get('name', 'deepseek-coder')
        mode = self.config.get('mode', 'prompt')  # 학습 모드 확인
        
        # DeepSeek-Coder 모델 로딩 시도
        if TRANSFORMERS_AVAILABLE and model_name == 'deepseek-coder':
            try:
                # AWS 환경에 맞는 기본 모델 경로 설정 (절대경로)
                base_model = "/home/ubuntu/deepseek-coder/models/deepseek-coder-6.7b-instruct"  # 기본 모델
                
                # 홈 디렉토리 기반 모델 루트 경로 설정
                model_base_path = os.path.expanduser("~/deepseek-coder/models")
                
                # 모드별 모델 경로 확인
                if mode == 'complete':
                    # 코드 자동완성 모드
                    model_path = os.path.join(model_base_path, "autocomplete-finetuned/final_model")
                    logger.info(f"[자동완성] 모드 - 모델 경로: {model_path}")
                elif mode == 'prompt':
                    # 프롬프트 모드
                    model_path = os.path.join(model_base_path, "prompt-finetuned/final_model")
                    logger.info(f"[프롬프트] 모드 - 모델 경로: {model_path}")
                elif mode == 'comment':
                    # 주석 모드
                    model_path = os.path.join(model_base_path, "comment-finetuned/final_model")
                    logger.info(f"[주석] 모드 - 모델 경로: {model_path}")
                elif mode == 'error_fix':
                    # 오류 수정 모드
                    model_path = os.path.join(model_base_path, "error-fix-finetuned/final_model")
                    logger.info(f"[오류수정] 모드 - 모델 경로: {model_path}")
                
                else:
                    # 알 수 없는 모드일 경우 기본으로 prompt 모드 사용
                    model_path = os.path.join(model_base_path, "prompt-finetuned/final_model")
                    logger.info(f"알 수 없는 모드: {mode}, 프롬프트 모델 사용: {model_path}")
                
                # 경로가 없으면 기본 모델 사용 예정임을 로깅
                if not os.path.exists(model_path):
                    logger.warning(f"지정된 모델 경로가 존재하지 않습니다: {model_path}")
                    logger.warning(f"기본 모델 {base_model}를 사용합니다.")
                    use_base_model = True
                else:
                    use_base_model = False
                
                # 양자화 설정
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True
                )
                
                try:
                    if use_base_model:
                        # 기본 모델 로드
                        logger.info(f"기본 모델 로드 시도: {base_model}")
                        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
                        self.model = AutoModelForCausalLM.from_pretrained(
                            base_model,
                            quantization_config=bnb_config,
                            device_map="auto"
                        )
                        logger.info(f"기본 모델 로드 성공!")
                    else:
                        # 학습된 모델 로드
                        logger.info(f"학습된 모델 로드 시도: {model_path}")
                        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
                        self.model = AutoModelForCausalLM.from_pretrained(
                            model_path,
                            quantization_config=bnb_config,
                            device_map="auto"
                        )
                        logger.info(f"학습된 {mode} 모델 로드 성공!")
                except Exception as e:
                    # 오류 발생시 기본 모델로 돌아가기
                    logger.warning(f"모델 로드 오류: {str(e)}")
                    logger.info(f"기본 모델 로드 시도(콜백): {base_model}")
                    try:
                        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
                        self.model = AutoModelForCausalLM.from_pretrained(
                            base_model,
                            quantization_config=bnb_config,
                            device_map="auto"
                        )
                        logger.info(f"기본 모델 로드 성공(fallback)!")
                    except Exception as nested_e:
                        logger.error(f"기본 모델 로드도 실패: {str(nested_e)}")
                        # 모델 로드가 완전히 실패한 경우 - dummy 모델 사용
                        self._create_dummy_model()
                        logger.warning("모든 모델 로드 시도 실패. 더미 모델로 최소 한의 실행을 보장합니다.")
                
                # LoRA 설정
                lora_config = LoraConfig(
                    r=16,
                    lora_alpha=32,
                    lora_dropout=0.1,
                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                    bias="none",
                    task_type="CAUSAL_LM"
                )
                
                # LoRA 적용
                self.model = get_peft_model(self.model, lora_config)
                
                # 모델 설정 저장
                self.feat_dim = 4096  # DeepSeek-Coder hidden size
                self.vocab_size = self.model.config.vocab_size  # 언어 모델 어휘 크기
                
                # 언어 모델 로그 정보
                logger.info(f"언어 모델 어휘 크기: {self.vocab_size}, 히든 차원: {self.feat_dim}")
                
                # 분류기와 FeatureAdapter 제거 - 언어 모델 내장 손실만 사용하도록 수정
                
                logger.info(f"DeepSeek-Coder 모델 로드 성공")
                
            except Exception as e:
                logger.error(f"DeepSeek-Coder 모델 로드 오류: {str(e)}")
                # 오류 발생시 더미 모델로 돌아감
                self._create_dummy_model()
        else:
            # Transformers 미설치 또는 지원되지 않는 모델일 경우 더미 모델 사용
            logger.warning(f"Transformers 라이브러리 미설치 또는 지원되지 않는 모델: {model_name}")
            self._create_dummy_model()
    
    def _create_dummy_model(self):
        """더미 모델 생성 (테스트용)"""
        logger.warning("더미 모델을 사용합니다. 실제 학습에는 적합하지 않습니다.")
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        ).to(self.device)
        
        self.feat_dim = 256
        self.vocab_size = 1000  # 더미 모델의 어휘 사이즈
        
        # 분류기와 Feature Adapter 제거
        # 언어 모델만 사용하고 내장 손실 함수 사용
        logger.info(f"더미 모델 생성 완료 - 히든 차원: {self.feat_dim}")


        # CapaBoost 적용
        if self.config.get('use_capaboost', True):
            self.apply_capaboost()

        # 그래디언트 체크포인팅 (메모리 최적화)
        if self.config.get('use_gradient_checkpointing', True):
            # Hugging Face 모델은 gradient_checkpointing_enable 메서드가 있지만 Sequential는 없음
            # 안전하게 메서드 존재 여부 확인 후 호출
            if hasattr(self.model, 'gradient_checkpointing_enable'):
                self.model.gradient_checkpointing_enable()
                logger.info("그래디언트 체크포인팅이 활성화되었습니다.")
            else:
                logger.warning("이 모델은 그래디언트 체크포인팅을 지원하지 않습니다.")

    def apply_capaboost(self):
        """CapaBoost를 모델의 선형 레이어에 적용"""
        capaboost_config = self.config.get('capaboost', {})
        self.mask_generator = CapaBoostMaskGenerator(
            sparsity=capaboost_config.get('sparsity', 0.6),
            num_masks=capaboost_config.get('num_masks', 4),
            seed=capaboost_config.get('seed', 42)
        )

        # 각 선형 레이어를 CapaBoost 레이어로 교체
        for name, module in list(self.model.named_children()):
            if isinstance(module, nn.Linear):
                # CapaBoost Linear로 교체
                capaboost_linear = CapaBoostLinear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                    mask_generator=self.mask_generator,
                    layer_name=name
                )
                # 가중치 복사
                capaboost_linear.linear.weight.data.copy_(module.weight.data)
                if module.bias is not None:
                    capaboost_linear.linear.bias.data.copy_(module.bias.data)

                # 모듈 교체
                if isinstance(self.model, nn.Sequential):
                    self.model._modules[name] = capaboost_linear
                else:
                    setattr(self.model, name, capaboost_linear)

        logger.info("CapaBoost가 모든 선형 레이어에 적용되었습니다.")

    def setup_components(self):
        """핵심 컴포넌트 초기화"""
        # 옵티마이저 설정
        # config.yaml에서 learning_rate 키를 직접 가져옴 (문자열을 float로 변환)
        learning_rate = float(self.config.get('learning_rate', '2e-4'))
        weight_decay = float(self.config.get('weight_decay', '0.01'))
        logger.info(f"학습률 초기화: {learning_rate}, 가중치 감소율: {weight_decay}")
        
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        # Fisher Information 계산기
        fisher_config = self.config.get('fisher', {})
        self.fisher_computer = FisherInformationComputer(
            mode=fisher_config.get('mode', 'empirical')
        )

        # EWC 설정
        ewc_config = self.config.get('ewc', {})
        self.ewc = ElasticWeightConsolidation(
            model=self.model,
            fisher_computer=self.fisher_computer,
            lambda_ewc=ewc_config.get('lambda', 5000.0),
            normalize_fisher=ewc_config.get('normalize', True)
        )

        # Meta-Experience Replay
        mer_config = self.config.get('mer', {})
        self.mer = MetaExperienceReplay(
            buffer_size=mer_config.get('buffer_size', 5120),
            beta=mer_config.get('beta', 0.1),
            gamma=mer_config.get('gamma', 0.1),
            s=mer_config.get('s', 10)
        )

        # 분류기 관련 손실 함수 제거 - 언어 모델 내장 손실만 사용
        logger.info("언어 모델링 내장 손실 함수 사용 설정")

        # 혼합 정밀도 훈련
        self.use_amp = self.config.get('use_amp', True) and torch.cuda.is_available()
        self.scaler = GradScaler() if self.use_amp else None
        if self.use_amp:
            logger.info("혼합 정밀도 훈련(AMP)이 활성화되었습니다.")

        # 학습률 스케줄러
        scheduler_config = self.config.get('scheduler', {})
        self.scheduler = AdaptiveLRScheduler(
            optimizer=self.optimizer,
            total_steps=scheduler_config.get('total_steps', 10000),
            warmup_steps=scheduler_config.get('warmup_steps', 100),
            min_lr_ratio=scheduler_config.get('min_lr_ratio', 0.1),
            end_lr_ratio=scheduler_config.get('end_lr_ratio', 0.01)
        )

    def train_task(self, task_id: int, train_loader: DataLoader, val_loader: DataLoader, test_loaders: Dict[int, DataLoader]):
        """
        단일 태스크에 대한 훈련 수행

        Args:
            task_id: 태스크 ID
            train_loader: 훈련 데이터 로더
            val_loader: 검증 데이터 로더
            test_loaders: 이전 태스크 테스트 데이터 로더 {task_id: loader}
        """
        self.current_task_id = task_id

        # 태스크 구성
        max_epochs = self.config.get('max_epochs', 10)
        patience = self.config.get('patience', 3)

        # 조기 종료 추적
        best_val_acc = 0.0
        patience_counter = 0

        # 태스크 훈련 시작
        logger.info(f"=== 태스크 {task_id} 훈련 시작 ===")

        for epoch in range(max_epochs):
            # 훈련 에포크
            train_loss, train_acc = self.train_epoch(
                train_loader=train_loader, 
                epoch=epoch,
                task_id=task_id
            )

            # 검증
            val_loss, val_acc = self.evaluate(val_loader)

            # 모든 이전 태스크에 대한 테스트
            test_results = {}
            for test_task_id, test_loader in test_loaders.items():
                _, test_acc = self.evaluate(test_loader)
                test_results[test_task_id] = test_acc

            # 메트릭 업데이트
            self.metrics.update(task_id, test_results)

            # 안정성 측정을 위해 이전 태스크 정확도 기록
            if task_id > 0 and 0 in test_results:
                self.metrics.record_stability_gap(
                    task_id, 
                    self.metrics.task_accuracies[0]
                )

            # 로깅
            metrics = self.metrics.compute_metrics()
            self.log_metrics(epoch, train_loss, train_acc, val_loss, val_acc, test_results, metrics)

            # 조기 종료 확인
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                # 최적 모델 저장
                self.save_model(f"best_model_task_{task_id}.pt")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logger.info(f"조기 종료: {patience}에포크 동안 개선 없음")
                    break

        # 태스크 완료 후 EWC 등록
        self.ewc.register_task(task_id, train_loader, self.ce_loss)

        # 최종 성능 저장
        self.best_accuracies[task_id] = best_val_acc

        # 종합 메트릭 로깅
        final_metrics = self.metrics.compute_metrics()
        logger.info(f"=== 태스크 {task_id} 훈련 완료 ===")
        logger.info(f"최종 메트릭: {final_metrics}")

        # 최종 모델 저장
        self.save_model(f"final_model_task_{task_id}.pt")

    def train_epoch(self, train_loader: DataLoader, epoch: int, task_id: int):
        """
        단일 에포크 훈련

        Args:
            train_loader: 훈련 데이터 로더
            epoch: 현재 에포크
            task_id: 태스크 ID

        Returns:
            평균 손실과 정확도
        """
        # 모델만 학습 모드로 설정(분류기 제거)
        self.model.train()

        total_loss = 0
        correct = 0
        total = 0
        
        # 로그로만 진행상황 표시
        logger.info(f"Epoch {epoch} 학습 시작: 총 {len(train_loader)} 배치")
            
        # 기본 tqdm 설정만 사용 (호환성 최대화)
        pbar = tqdm(
            train_loader, 
            desc=f"Epoch {epoch}"
        )

        # 동적 배치 크기 적용
        current_batch_size = self.batch_manager.batch_size

        for batch_idx, batch_data in enumerate(pbar):
            try:
                # BatchEncoding 타입 또는 딕셔너리 형태의 데이터 처리
                if isinstance(batch_data, dict) or hasattr(batch_data, 'input_ids'):
                    # input_ids를 입력 데이터로 사용
                    data = batch_data.get('input_ids', None)
                    if data is None and 'input_ids' in batch_data:
                        data = batch_data['input_ids']
                    
                    # data가 여전히 None이면 배치 건너뛰기
                    if data is None:
                        logger.warning(f"배치 {batch_idx}에서 input_ids를 찾을 수 없습니다. 건너뜁니다.")
                        continue
                    
                    data = data.to(self.device)
                    
                    # 언어 모델링에서는 labels가 있을 수 있음
                    if 'labels' in batch_data:
                        target = batch_data['labels'].to(self.device)
                    else:
                        # input_ids를 target으로 사용
                        target = data
                    
                    # 디버그용 로그
                    if batch_idx == 0:
                        logger.info(f"배치 데이터 형태: {type(batch_data).__name__}, 키: {batch_data.keys() if hasattr(batch_data, 'keys') else 'N/A'}")
                        logger.info(f"입력 데이터 형태: {data.shape if hasattr(data, 'shape') else 'unknown'}")
                else:
                    # 기존 (data, target) 튜플 형태 처리
                    data, target = batch_data
                    data, target = data.to(self.device), target.to(self.device)

                # 학습 단계 수행
                if self.use_amp:
                    loss, acc = self.train_step_amp(data, target, task_id)
                else:
                    loss, acc = self.train_step(data, target, task_id)
                    
                # 배치 크기 업데이트 (OOM 없음)
                new_batch_size = self.batch_manager.update()
                if new_batch_size != current_batch_size:
                    current_batch_size = new_batch_size
                    logger.info(f"배치 크기 업데이트: {current_batch_size}")
                        
                # 프로그레스 바 통합 업데이트 (출력 빈도 최적화)
                if batch_idx % 50 == 0 or batch_idx == 0:  # 첫 배치와 50배치마다 프로그레스 업데이트
                    # 누적 손실과 정확도 계산
                    running_loss = total_loss / (batch_idx + 1)
                    running_acc = correct / max(total, 1)
                    
                    try:
                        # 표시와 postfix 모두 한 번에 업데이트
                        pbar.set_description(f"Epoch {epoch} [L: {loss:.4f}]")
                        pbar.set_postfix({
                            'loss': f"{running_loss:.4f}",
                            'acc': f"{running_acc:.4f}",
                            'lr': f"{self.optimizer.param_groups[0]['lr']:.2e}",
                            'bs': current_batch_size
                        }, refresh=False)
                    except Exception as e:
                        # tqdm 버전 호환성 오류 무시
                        pass
                    
                    # 디버그 로그 - 첫 배치와 100배치마다만 출력
                    if batch_idx == 0 or batch_idx % 100 == 0:
                        if isinstance(data, torch.Tensor):
                            logger.debug(f"데이터 텐서 정보 - 형태: {data.shape}, 디바이스: {data.device}, 타입: {data.dtype}")
                        elif hasattr(data, 'input_ids') and hasattr(data.input_ids, 'shape'):
                            logger.debug(f"입력 ID 텐서 정보 - 형태: {data.input_ids.shape}, 디바이스: {data.input_ids.device}, 타입: {data.input_ids.dtype}")

            except torch.cuda.OutOfMemoryError:
                # OOM 오류 발생 시 복구
                torch.cuda.empty_cache()
                logger.warning("OOM 오류 발생, 복구 중...")

                # 배치 크기 감소
                new_batch_size = self.batch_manager.update(oom_occurred=True)
                current_batch_size = new_batch_size

                # 해당 배치 건너뛰기
                continue
            except Exception as e:
                # "No inf checks" 오류는 학습에 영향을 주지 않으므로 무시
                if "No inf checks" in str(e):
                    # 오류 무시 - 학습 진행
                    pass
                else:
                    # 배치 처리 중 오류는 로그에 기록
                    logger.warning(f"배치 처리 중 오류 발생: {e}")
                    if self.global_step % 50 == 0:  # 로그 양 제한
                        if isinstance(batch_data, dict) or hasattr(batch_data, 'keys'):
                            logger.debug(f"배치 키: {list(batch_data.keys()) if hasattr(batch_data, 'keys') else 'N/A'}")
                        # 배치 데이터 구조 출력
                        logger.error(f"배치 데이터 구조: {type(batch_data)}")
                        # 데이터 샘플 출력
                        if hasattr(batch_data, 'input_ids'):
                            logger.info(f"input_ids 형태: {type(batch_data.input_ids)}, 크기: {batch_data.input_ids.shape if hasattr(batch_data.input_ids, 'shape') else 'N/A'}")
                continue

            # 손실 및 정확도 누적
            total_loss += loss
            # data.size(0) 대신 안전하게 배치 크기 추출
            batch_size = data.size(0) if hasattr(data, 'size') else 1
            correct += acc * batch_size
            total += batch_size

            # 진행률 표시 - 예외 처리 추가
            try:
                pbar.set_postfix({
                    'loss': f"{loss:.4f}",
                    'acc': f"{acc*100:.2f}%",
                    'lr': f"{self.scheduler.get_last_lr()[0]:.6f}"
                }, refresh=False)  # 성능 향상을 위해 refresh=False 추가
            except Exception as e:
                # tqdm 버전 호환성 오류 무시
                pass

            # 체크포인트 확인
            if self.spot_handler.should_checkpoint(self.global_step):
                self.spot_handler.save_checkpoint(
                    model=self.model,
                    optimizer=self.optimizer,
                    scheduler=self.scheduler,
                    scaler=self.scaler,
                    trainer=self,
                    step=self.global_step
                )

            # 전역 스텝 업데이트
            self.global_step += 1

        # 에포크 종료 후 최종 요약 출력
        logger.info(f"에포크 {epoch} 완료: 손실={total_loss/(batch_idx+1):.4f}, 정확도={correct/max(1,total):.4f}")
        
        # 최종 메트릭 계산
        return total_loss / len(train_loader), correct / total if total > 0 else 0

    def train_step(self, data: torch.Tensor, target: torch.Tensor, task_id: int):
        """
        단일 훈련 스텝 (표준 정밀도) - 언어 모델 내장 손실만 사용

        Args:
            data: 입력 데이터
            target: 타겟 레이블
            task_id: 현재 태스크 ID
            
        Returns:
            손실, 정확도
        """
        # 언어 모델링 손실 계산을 위해 labels 인자로 target 전달
        model_output = self.model(data, labels=target)
        
        # 손실 및 정확도 초기화
        loss = None
        acc = 0.0
        
        try:
            # 모델 내장 손실 사용
            if hasattr(model_output, 'loss') and model_output.loss is not None:
                loss = model_output.loss
                if self.global_step % 10 == 0:
                    logger.info(f"[INFO] 언어 모델 내장 손실 사용: {loss.item():.4f}")
                    
                # 정확도 계산 (선택적)
                if hasattr(model_output, 'logits'):
                    # 마지막 토큰 위치의 예측값만 사용하여 간단한 정확도 계산
                    if len(target.shape) > 1 and len(model_output.logits.shape) > 2:
                        # 실제 타겟의 마지막 유효한 토큰 인덱스 계산 (패딩 제외)
                        pred = model_output.logits[:, :-1].argmax(dim=-1)
                        valid_target = target[:, 1:]
                        valid_mask = (valid_target != -100) & (valid_target != 0)  # 패딩 토큰 및 특수 토큰 제외
                        
                        if valid_mask.sum() > 0:
                            correct = (pred == valid_target) & valid_mask
                            acc = correct.sum().float() / valid_mask.sum()
            else:
                # 내장 손실이 없는 경우 - 예외 상황
                available_attrs = [attr for attr in dir(model_output) if not attr.startswith('_')]
                logger.error(f"model_output 속성: {available_attrs}")
                raise ValueError(f"모델에서 언어 모델링 손실을 찾을 수 없음: {type(model_output)}")
                    
                    
                # logits이 있는 경우 직접 손실 계산    
        except Exception as e:
            logger.error(f"배치 처리 중 오류 발생: {e}")
            logger.error(f"입력 데이터 구조: {type(data)}")
            
            # 대체 접근 방법 시도
            logger.warning(f"기본 방식 손실 계산 중 오류 발생: {str(e)}")
            logger.warning("내장 손실 계산 재시도...")
            
            # 데이터에 레이블이 포함된 경우 (HuggingFace 형식)
            if isinstance(data, dict) and 'labels' in data:
                model_output = self.model(**data)  # 레이블 포함하여 다시 계산
            else:
                # 다양한 방식 시도
                try:
                    model_output = self.model(data, labels=target)  # 레이블 명시적 전달
                except Exception as inner_e:
                    logger.error(f"레이블 전달 시도 실패: {inner_e}")
                    # 마지막 시도: 레이블 없이 실행 후 손실 수동 계산
                    model_output = self.model(data)
                
            if hasattr(model_output, 'loss'):
                loss = model_output.loss
                logger.info(f"[복구] 내장 손실 사용: {loss.item():.4f}")
            else:
                logger.error("내장 손실 계산 실패, 학습 중단 필요")
                raise RuntimeError("모델의 언어 모델링 손실 계산 실패")
        
        # 학습률 스케줄러 스텝 (있는 경우)
        if self.scheduler is not None:
            self.scheduler.step()

        # MER 메모리 업데이트
        self.mer.reservoir_sampling(data, target)

        # 결과 반환: 손실 값과 정확도(있는 경우)
        return loss.item(), acc

    def train_step_amp(self, data, target, task_id: int):
        """
        단일 훈련 스텝 (혼합 정밀도) - 언어 모델 내장 손실만 사용

        Args:
            data: 입력 데이터 (torch.Tensor 또는 transformers.BatchEncoding)
            target: 타겟 레이블 (torch.Tensor 또는 transformers.BatchEncoding)
            task_id: 현재 태스크 ID

        Returns:
            손실 값, 정확도
        """
        self.optimizer.zero_grad()

        # 데이터 타입 체크 및 처리
        from transformers.tokenization_utils_base import BatchEncoding
        
        # BatchEncoding 객체 처리
        if isinstance(data, BatchEncoding):
            # BatchEncoding 로깅 최적화: 초기 한 번만 info로 기록, 이후에는 매우 낮은 빈도로 debug로 기록
            if self.global_step == 0:  # 학습 첫 시작시에만 info 레벨로 기록
                logger.info(f"BatchEncoding 데이터 처리: 키={list(data.keys())}, 디바이스={data.get('input_ids').device if 'input_ids' in data else '알 수 없음'}")
            elif self.global_step % 5000 == 0:  # 이후에는 5000 스텝마다 debug 레벨로만 기록
                logger.debug(f"BatchEncoding 데이터 처리: 키={list(data.keys())}, 디바이스={data.get('input_ids').device if 'input_ids' in data else '알 수 없음'}")
            
            # 필요한 경우에만 input_ids 추출
            if 'input_ids' in data:
                input_tensor = data['input_ids']
            else:
                available_keys = list(data.keys())
                logger.error(f"BatchEncoding에서 input_ids를 찾을 수 없음. 사용 가능한 키: {available_keys}")
                raise ValueError(f"BatchEncoding에 input_ids가 없습니다. 사용 가능한 키: {available_keys}")
        else:
            # 이미 텐서인 경우 그대로 사용
            input_tensor = data
        
        # target도 동일하게 처리
        if isinstance(target, BatchEncoding):
            if 'labels' in target:
                labels_tensor = target['labels']
            elif 'input_ids' in target:
                # fallback: input_ids를 labels로 사용
                labels_tensor = target['input_ids']
            else:
                available_keys = list(target.keys())
                logger.error(f"BatchEncoding에서 labels를 찾을 수 없음. 사용 가능한 키: {available_keys}")
                raise ValueError(f"BatchEncoding에 labels가 없습니다. 사용 가능한 키: {available_keys}")
        else:
            # 이미 텐서인 경우 그대로 사용
            labels_tensor = target

        # 혼합 정밀도 연산 - 최신 PyTorch API 형식 사용
        with autocast('cuda'):
            # 언어 모델링 손실 계산을 위해 labels 인자로 target 전달
            # DeepSeek-Coder 및 다른 Hugging Face 모델은 labels를 통한 직접 손실 계산 지원
            model_output = self.model(input_tensor, labels=labels_tensor)
            
            # 언어 모델 내장 손실 사용
            try:
                # 1. 모델이 내장 손실을 제공하는 경우 (허깅페이스 transformers 모델)
                if hasattr(model_output, 'loss') and model_output.loss is not None:
                    # 손실을 추출하고 그래디언트 추적을 위해 명시적으로 requires_grad=True 설정
                    loss = model_output.loss
                    
                    # 그래디언트 추적이 필요하지만 없는 경우
                    if not loss.requires_grad:
                        # 새로운 텐서로 복사하고 그래디언트 추적 활성화
                        loss = loss.detach().clone().requires_grad_(True)
                    
                    # 로그 출력 빈도 조절 (500배치마다 한 번만 출력)
                    if self.global_step % 500 == 0:
                        logger.info(f"LM 내장 손실 사용: {loss.item():.4f}, requires_grad={loss.requires_grad}, 학습률={self.scheduler.get_last_lr()[0]:.6f}")
                else:
                    # 내장 손실이 없는 경우 - 로깅 후 예외 발생
                    available_attrs = [attr for attr in dir(model_output) if not attr.startswith('_')]
                    logger.error(f"model_output 속성: {available_attrs}")
                    raise ValueError(f"모델에서 언어 모델링 손실을 찾을 수 없음: {type(model_output)}")
                    
            except Exception as e:
                logger.error(f"손실 계산 중 오류 발생: {e}")
                logger.error(f"model_output 타입: {type(model_output)}")
                raise RuntimeError(f"LM 손실 계산 실패: {e}")
                
            # 정확도 계산 (선택적) - 완전한 분류기 대신 단순한 정확도 추정 사용
            acc = 0.0
            try:
                if hasattr(model_output, 'logits'):
                    # 마지막 토큰 위치의 예측값만 사용하여 간단한 정확도 계산 
                    # (실제 LM 성능은 perplexity 등으로 별도 평가해야 함)
                    if len(target.shape) > 1 and len(model_output.logits.shape) > 2:
                        # 실제 타겟의 마지막 유효한 토큰 인덱스 계산 (패딩 제외)
                        pred = model_output.logits[:, :-1].argmax(dim=-1)
                        valid_target = target[:, 1:]
                        valid_mask = (valid_target != -100) & (valid_target != 0)  # 패딩 토큰 및 특수 토큰 제외
                        
                        if valid_mask.sum() > 0:
                            correct = (pred == valid_target) & valid_mask
                            acc = correct.sum().float() / valid_mask.sum()
            except Exception as e:
                logger.warning(f"정확도 계산 중 오류 발생 (무시됨): {e}")
                acc = 0.0
                
            # EWC 손실 (이전 태스크 보존) - 필요한 경우
            try:
                # EWC 손실 계산 전에 자원 확보
                if hasattr(self, 'ewc') and self.ewc is not None:
                    ewc_loss = self.ewc.compute_ewc_loss()
                    
                    # 그래디언트 추적이 필요하면 자동으로 설정
                    if not ewc_loss.requires_grad and ewc_loss.item() != 0.0:
                        # 가능한 한 조용히 처리 - 로그 없이 자동 변환
                        ewc_loss = ewc_loss.detach().clone().requires_grad_(True)
                    
                    # 손실값 유효성 검사
                    if torch.isfinite(ewc_loss).all() and torch.isfinite(loss).all():
                        loss = loss + ewc_loss
                        if self.global_step % 20 == 0 and ewc_loss.item() > 0.0001:
                            # 유의미한 EWC 손실이 있을 때만 기록
                            logger.info(f"LM 손실: {loss.item():.4f}, EWC 손실: {ewc_loss.item():.4f}")
                else:
                    # EWC가 없는 경우 아무것도 하지 않음
                    pass
            except Exception as e:
                # 중요하지 않은 예외는 조용히 무시
                pass

        # 역전파 전에 손실이 그래디언트 추적을 요구하는지 확인
        try:
            if not loss.requires_grad:
                logger.warning(f"[안전 조치] 손실이 그래디언트 추적을 요구하지 않음. 명시적으로 requires_grad=True 설정")
                loss = loss.detach().clone().requires_grad_(True)
            
            # 스케일링된 역전파
            scaled_loss = self.scaler.scale(loss)
            scaled_loss.backward()
            
            if self.global_step % 50 == 0:
                logger.debug(f"backward 호출 성공: 손실={loss.item():.4f}, requires_grad={loss.requires_grad}")
        except Exception as e:
            logger.error(f"역전파 오류: {e}")
            # 간단한 대체 역전파 시도 (응급 조치)
            try:
                if hasattr(loss, 'backward'):
                    loss.backward()
                    logger.info("대체 역전파 방법 사용")
            except Exception as e2:
                logger.error(f"대체 역전파도 실패: {e2}")
                # 실패하면 현재 배치 건너뛼

        # 그래디언트 클리핑 (선택적)
        if self.config.get('grad_clip', 0) > 0:
            try:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), 
                    self.config.get('grad_clip', 1.0)
                )
            except RuntimeError as e:
                logger.warning(f"그래디언트 클리핑 중 오류 (무시함): {e}")

        # 스케일러를 통한 옵티마이저 스텝 - 오류 처리 개선
        try:
            # 옵티마이저 스텝 수행
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            # 학습률 스케쥴러 업데이트
            if self.scheduler is not None:
                self.scheduler.step()
                
        except RuntimeError as e:
            error_msg = str(e)
            if "No inf checks" in error_msg:
                # 이 오류는 학습에 실질적 영향이 없으므로 로그 없이 무시
                pass
            elif "element 0 of tensors does not require grad" in error_msg:
                # 그래디언트 추적 문제 - 다음 배치에서 문제 해결 시도
                logger.warning("[문제 해결 시도] 텐서에 그래디언트 추적이 활성화되지 않은 문제")
                # 옵티마이저 상태 초기화 (경계 상태)
                self.optimizer.zero_grad()
            else:
                # 다른 중요한 오류는 기록 후 넘김
                logger.warning(f"옵티마이저 스텝 중 오류 발생: {e}")
                # 특정 유형의 오류만 무시하고 다음 배치로 진행
                if "CUDA out of memory" in error_msg or "c10:" in error_msg:
                    logger.error("추가 디버그 정보: CUDA 특정 오류 문제 무시")
                    # 메모리 정리
                    torch.cuda.empty_cache()
                else:
                    raise

        # 학습률 스케줄러 스텝
        self.scheduler.step()

        # MER 메모리 업데이트 (한번만 실행)
        self.mer.reservoir_sampling(data, target)

        return loss.item(), acc

    def evaluate(self, data_loader: DataLoader):
        """
        모델 평가

        Args:
            data_loader: 평가 데이터 로더

        Returns:
            평균 손실, 정확도
        """
        self.model.eval()

        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in data_loader:
                data, target = data.to(self.device), target.to(self.device)

                try:
                    # 언어 모델의 내장 손실 함수 사용
                    outputs = self.model(data, labels=target)
                    loss = outputs.loss
                    total_loss += loss.item()

                    # 로그엣과 평균 정확도 계산 (간단한 버전)
                    logits = outputs.logits  # [batch_size, sequence_length, vocab_size]
                    shift_logits = logits[:, :-1, :].contiguous()
                    shift_labels = target[:, 1:].contiguous()
                    
                    # 참고: 여기서의 정확도는 단어 레벨 정확도의 대략적인 추정
                    flattened_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
                    flattened_shift_labels = shift_labels.view(-1)
                    
                    # 유효한 토큰에 대해서만 정확도 계산 (패딩 토큰 제외)
                    valid_indices = (flattened_shift_labels != -100)
                    if valid_indices.sum() > 0:
                        valid_logits = flattened_shift_logits[valid_indices]
                        valid_labels = flattened_shift_labels[valid_indices]
                        
                        predictions = valid_logits.argmax(dim=-1)
                        correct += (predictions == valid_labels).sum().item()
                        total += valid_indices.sum().item()
                    
                except Exception as e:
                    logger.error(f"평가 중 오류 발생: {str(e)}")
                    continue

        # 평균 손실 및 정확도 계산
        avg_loss = total_loss / len(data_loader) if len(data_loader) > 0 else 0
        accuracy = correct / total if total > 0 else 0

        return avg_loss, accuracy

    def log_metrics(self, epoch: int, train_loss: float, train_acc: float, 
                 val_loss: float, val_acc: float, test_results: Dict[int, float],
                 metrics: Dict[str, float]):
        """
        메트릭 로깅

        Args:
            epoch: 현재 에포크
            train_loss: 훈련 손실
            train_acc: 훈련 정확도
            val_loss: 검증 손실
            val_acc: 검증 정확도
            test_results: 이전 태스크 테스트 결과
            metrics: 계산된 메트릭
        """
        # 콘솔 로깅
        log_str = f"Epoch {epoch}: "
        log_str += f"Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.2f}% | "
        log_str += f"Val Loss: {val_loss:.4f}, Acc: {val_acc*100:.2f}% | "

        for task_id, acc in test_results.items():
            log_str += f"Task {task_id} Acc: {acc*100:.2f}% | "

        log_str += f"ACC: {metrics['ACC']*100:.2f}%, BWT: {metrics['BWT']*100:.2f}%, "
        log_str += f"Stability Gap: {metrics['stability_gap']*100:.2f}%"

        logger.info(log_str)

        # TensorBoard 로깅
        step = epoch + self.current_task_id * self.config.get('max_epochs', 10)

        self.writer.add_scalar('Loss/train', train_loss, step)
        self.writer.add_scalar('Loss/val', val_loss, step)
        self.writer.add_scalar('Accuracy/train', train_acc, step)
        self.writer.add_scalar('Accuracy/val', val_acc, step)

        for task_id, acc in test_results.items():
            self.writer.add_scalar(f'Accuracy/task_{task_id}', acc, step)

        self.writer.add_scalar('Metrics/ACC', metrics['ACC'], step)
        self.writer.add_scalar('Metrics/BWT', metrics['BWT'], step)
        self.writer.add_scalar('Metrics/FWT', metrics['FWT'], step)
        self.writer.add_scalar('Metrics/stability_gap', metrics['stability_gap'], step)
        self.writer.add_scalar('Metrics/recovery_ratio', metrics['recovery_ratio'], step)

        # 학습률 로깅
        self.writer.add_scalar('LR', self.scheduler.get_last_lr()[0], step)

    def save_model(self, filename: str):
        """
        모델 저장

        Args:
            filename: 저장할 파일 이름
        """
        save_path = os.path.join(self.config.get('checkpoint_dir', 'checkpoints'), filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        state_dict = {
            'model': self.model.state_dict(),
            # 분류기 제거 - 언어 모델 내장 손실만 사용
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict() if self.scheduler else None,
            'scaler': self.scaler.state_dict() if self.scaler else None,
            'metrics': self.metrics.get_state_dict(),
            'config': self.config,
            'global_step': self.global_step,
            'current_task_id': self.current_task_id,
            'best_accuracies': self.best_accuracies,
            'ewc': self.ewc.get_state_dict(),
            'mer': self.mer.get_state_dict()
        }

        torch.save(state_dict, save_path)
        logger.info(f"모델이 저장되었습니다: {save_path}")

    def load_model(self, filename: str):
        """
        모델 로드

        Args:
            filename: 로드할 파일 이름
        """
        load_path = os.path.join(self.config.get('checkpoint_dir', 'checkpoints'), filename)
        if not os.path.exists(load_path):
            logger.warning(f"로드할 모델 파일이 없습니다: {load_path}")
            return False

        state_dict = torch.load(load_path, map_location=self.device)

        self.model.load_state_dict(state_dict['model'])
        # 분류기 로드 제거 - 언어 모델만 사용
        self.optimizer.load_state_dict(state_dict['optimizer'])

        if 'scheduler' in state_dict and state_dict['scheduler'] and self.scheduler:
            self.scheduler.load_state_dict(state_dict['scheduler'])

        if 'scaler' in state_dict and state_dict['scaler'] and self.scaler:
            self.scaler.load_state_dict(state_dict['scaler'])

        if 'metrics' in state_dict:
            self.metrics.load_state_dict(state_dict['metrics'])

        if 'ewc' in state_dict:
            self.ewc.load_state_dict(state_dict['ewc'])

        if 'mer' in state_dict:
            self.mer.load_state_dict(state_dict['mer'])

        self.global_step = state_dict.get('global_step', 0)
        self.current_task_id = state_dict.get('current_task_id', 0)
        self.best_accuracies = state_dict.get('best_accuracies', {})

        logger.info(f"모델을 로드했습니다: {load_path}")
        return True

    def get_state_dict(self):
        """상태 저장을 위한 딕셔너리 반환"""
        return {
            'metrics': self.metrics.get_state_dict(),
            'global_step': self.global_step,
            'current_task_id': self.current_task_id,
            'best_accuracies': self.best_accuracies,
            'ewc': self.ewc.get_state_dict(),
            'mer': self.mer.get_state_dict()
        }

    def load_state_dict(self, state_dict):
        """저장된 상태 로드"""
        if 'metrics' in state_dict:
            self.metrics.load_state_dict(state_dict['metrics'])

        if 'ewc' in state_dict:
            self.ewc.load_state_dict(state_dict['ewc'])

        if 'mer' in state_dict:
            self.mer.load_state_dict(state_dict['mer'])

        self.global_step = state_dict.get('global_step', 0)
        self.current_task_id = state_dict.get('current_task_id', 0)
        self.best_accuracies = state_dict.get('best_accuracies', {})

# 메인 함수
# wandb 명시적 비활성화 (스크립트 맨 위에 추가)
os.environ["WANDB_DISABLED"] = "true"

def main():
    """메인 실행 함수"""
    parser = argparse.ArgumentParser(description='CapaBoost + Continual Learning 통합 시스템')
    parser.add_argument('--config', type=str, default='config.yaml', help='설정 파일 경로')
    parser.add_argument('--mode', type=str, choices=['complete', 'prompt', 'comment', 'error_fix'], default='prompt',
                        help='학습 모드: complete(자동완성), prompt(일반 프롬프트), comment(주석), error_fix(오류 수정)')
    parser.add_argument('--seed', type=int, default=42, help='랜덤 시드')
    parser.add_argument('--resume', action='store_true', help='이전 체크포인트에서 재개')
    parser.add_argument('--resume_from_checkpoint', type=str, help='특정 체크포인트 경로에서 재개')
    parser.add_argument('--enable_capaboost', action='store_true', help='CapaBoost 기능 활성화')
    args = parser.parse_args()

    # 설정 파일 로드
    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # 명령행 인수로 설정 덮어쓰기
    if args.seed:
        config['seed'] = args.seed
        
    # 학습 모드 추가
    config['mode'] = args.mode
    logger.info(f"\ud559\uc2b5 \ubaa8\ub4dc: {args.mode}")
    
    # 모드별 출력 경로 설정
    if args.mode == 'complete':
        config['output_dir'] = config.get('output_dir', '../models/autocomplete-finetuned/')
        logger.info("[1차] 코드 자동완성 학습 모드 (FIM 형식 지원)")
    elif args.mode == 'prompt':
        config['output_dir'] = config.get('output_dir', '../models/prompt-finetuned/')
        logger.info("[2차] 일반 프롬프트 기반 코드 생성 모드")
    elif args.mode == 'comment':
        config['output_dir'] = config.get('output_dir', '../models/comment-finetuned/')
        logger.info("[3차] 주석 기반 코드 생성 모드")
    elif args.mode == 'error_fix':
        config['output_dir'] = config.get('output_dir', '../models/error-fix-finetuned/')
        logger.info("[4차] 코드 오류 설명 및 수정 모드")

    # CapaBoost 활성화 여부 로깅 및 설정에 추가
    capaboost_enabled = args.enable_capaboost or (os.environ.get('ENABLE_CAPABOOST', '0') == '1')
    if capaboost_enabled:
        logger.info("CapaBoost 기능이 활성화되었습니다.")
        config['enable_capaboost'] = True
    else:
        logger.info("CapaBoost 기능이 비활성화되었습니다.")
        config['enable_capaboost'] = False
        
    # 통합 학습기 초기화
    learner = ContinualLearner(config)

    # 체크포인트 경로 설정
    checkpoint_path = None
    if args.resume_from_checkpoint:
        checkpoint_path = args.resume_from_checkpoint
        logger.info(f"\uccb4\ud06c\ud3ec\uc778\ud2b8 \uacbd\ub85c\uc5d0\uc11c \uc7ac\uac1c: {checkpoint_path}")
    elif args.resume:
        checkpoint_path = "latest"  # 최근 체크포인트에서 재개
        logger.info("\ucd5c\uc2e0 \uccb4\ud06c\ud3ec\uc778\ud2b8\uc5d0\uc11c \uc7ac\uac1c")
    
    # 체크포인트 로드
    if checkpoint_path:
        if checkpoint_path == "latest":
            learner.spot_handler.load_latest_checkpoint(
                model=learner.model,
                optimizer=learner.optimizer,
                scheduler=learner.scheduler,
                scaler=learner.scaler,
                trainer=learner
            )
        else:
            learner.spot_handler.load_checkpoint(
                checkpoint_path,
                model=learner.model,
                optimizer=learner.optimizer,
                scheduler=learner.scheduler,
                scaler=learner.scaler,
                trainer=learner
            )

        # 모드별 데이터 전처리 및 포맷 적용
    logger.info(f"모드별 데이터 처리: {args.mode} 모드용 데이터셋 준비 중...")
    
    # 데이터 로더 생성 - 모드별 처리 로직 적용
    # train.py의 데이터 형식 처리 로직을 가져와서 구현
    train_dataset = None
    val_dataset = None
    
    # 데이터 셋 경로 설정 - 사용자 지정 경로로 변경
    data_path = os.path.expanduser("~/deepseek-coder/data")
    logger.info(f"데이터 디렉토리 경로: {data_path}")
    
    # 모드별 추가 로깅
    if args.mode == 'complete':
        logger.info(f"자동완성(FIM) 모드 사용")
    elif args.mode == 'prompt':
        logger.info(f"프롬프트 생성 모드 사용")
    elif args.mode == 'comment': 
        logger.info(f"주석 기반 생성 모드 사용")
    elif args.mode == 'error_fix':
        logger.info(f"오류 수정 모드 사용")
    
    # 데이터셋 로드 (datasets 라이브러리 사용)
    if not DATASETS_AVAILABLE:
        raise ImportError("datasets 라이브러리를 찾을 수 없습니다. 'pip install datasets'를 실행하여 설치하세요.")
        
    try:
        logger.info(f"데이터셋 로드 중: {data_path}")
        
        # 데이터셋 파일 존재 여부 확인 - 훈련 데이터만 체크 (.jsonl 형식)
        train_path = os.path.join(data_path, 'train.jsonl')
        
        if not os.path.exists(train_path):
            logger.error(f"훈련 데이터 파일을 찾을 수 없습니다: {train_path}")
            raise FileNotFoundError(f"훈련 데이터 파일을 찾을 수 없습니다: {train_path}")
            
        # 훈련 데이터 파일만 사용 (검증 데이터는 후에 분리)
        data_files = {'train': train_path}
            
        raw_dataset = load_dataset(
            'json', 
            data_files=data_files,
            cache_dir=config.get('cache_dir', '.cache'),
            # jsonl 형식 처리를 위한 옵션
            streaming=False  # 지정한 경우에만 False로 설정하여 메모리에 로드
        )
        
        # 훈련/검증 분리
        validation_split_percentage = config.get('validation_split_percentage', 10)
        logger.info(f"훈련 데이터에서 검증 데이터 {validation_split_percentage}% 분리 시작")
        
        # 기존 validation 세트가 없는 경우에만 분리 수행
        if 'validation' not in raw_dataset:
            split_datasets = raw_dataset["train"].train_test_split(
                test_size=validation_split_percentage / 100,
                seed=config.get('seed', 42)
            )
            raw_dataset["train"] = split_datasets["train"]
            raw_dataset["validation"] = split_datasets["test"]
            
        logger.info(f"최종 데이터셋 크기 - 훈련: {len(raw_dataset['train'])}, 검증: {len(raw_dataset['validation'])}")
        
        
        # 모드별 적합한 데이터 전처리
        # 기본 DeepSeek-Coder 모델 경로 지정 (절대경로 사용)
        default_model_path = "/home/ubuntu/deepseek-coder/model_cache/deepseek-coder-6.7b-instruct"
        
        # 구성에서 모델 경로 가져오기 (model_name 또는 model_name_or_path)
        model_path = config.get('model_name', config.get('model_name_or_path', default_model_path))
        
        logger.info(f"토크나이저 로드 중: {model_path}")
        
        try:
            # 로컬 경로인지 확인 (절대 경로 또는 ~로 시작하는 경로)
            is_local_path = os.path.isabs(model_path) or model_path.startswith('~') or model_path.startswith('/')
            
            # 로컬 경로면 local_files_only=True 옵션 추가
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                cache_dir=config.get('cache_dir', '.cache'),
                use_fast=True,
                local_files_only=is_local_path,  # 로컬 파일만 사용
                trust_remote_code=True  # 안전한 환경에서 원격 코드 신뢰
            )
            logger.info(f"토크나이저 로드 성공: {model_path}")
        except Exception as e:
            logger.error(f"토크나이저 로드 오류: {e}")
            logger.info(f"기본 모델 {default_model_path}로 재시도 중...")
            
            # 기본 DeepSeek-Coder 모델로 재시도 (로컬 파일 사용 옵션 추가)
            tokenizer = AutoTokenizer.from_pretrained(
                default_model_path,
                cache_dir=config.get('cache_dir', '.cache'),
                use_fast=True,
                local_files_only=True,  # 로컬 파일만 사용
                trust_remote_code=True  # 안전한 환경에서 원격 코드 신뢰
            )
        
        # 최대 길이 설정
        max_length = config.get('max_length', 2048)
        
        # 데이터 형식을 자동으로 감지하는 함수 추가
        def detect_format(example):
            """데이터 형식을 감지하는 함수"""
            if "messages" in example:
                return "chat"
            elif "prompt" in example and "completion" in example:
                return "prompt_completion"
            elif "prefix_code" in example and "suffix_code" in example and "comment" in example and "target_code" in example:
                # 주석 기반 코드 생성 형식 - 새로운 데이터 구조
                return "comment_to_code"
            elif "error_context" in example and "explanation" in example:
                return "error_explanation"
            elif "error_context" in example and "fixed_code_snippet" in example:
                # error_fix 형식 감지 개선 - buggy_code_snippet이 error_context 내부에 있는 경우도 처리
                error_context = example["error_context"]
                if isinstance(error_context, dict) and "buggy_code_snippet" in error_context:
                    return "error_fix"
                # 또는 buggy_code_snippet이 최상위 필드에 있는 경우
                elif "buggy_code_snippet" in example:
                    return "error_fix"
            elif "instruction" in example and "input" in example and "output" in example:
                return "instruction_input_output"
            elif "content" in example and args.mode == "complete":
                # complete 모드용 단순 content 필드만 있는 경우 (FIM 형식일 수 있음)
                return "complete"
            elif "comment" in example and "code" in example and args.mode == "comment":
                # 주석과 코드가 있는 단순 형식
                return "comment_code"
            else:
                # 데이터 구조 로깅
                logger.warning(f"⚠️ 알 수 없는 데이터 형식: {list(example.keys())}")
                return "unknown"
                
        # 데이터 형식 감지 (첫 번째 샘플로 판단)
        try:
            sample = raw_dataset["train"][0]
            data_format = detect_format(sample)
            logger.info(f"✅ 감지된 데이터 형식: {data_format}")
        except Exception as e:
            logger.error(f"데이터 형식 감지 중 오류 발생: {e}")
            data_format = "unknown"
        
        # FIM 형식 사용 여부를 미리 확인 (로그 중복 방지)
        has_fim_format = False
        if data_format == "prompt_completion" and args.mode == "comment":
            # 첫 몇 개 샘플 검사하여 FIM 태그 사용 여부 확인
            sample_check_count = min(5, len(raw_dataset["train"]))
            for i in range(sample_check_count):
                sample_prompt = raw_dataset["train"][i]["prompt"]
                if "<|fim begin|>" in sample_prompt or "<|fim hole|>" in sample_prompt or "<|fim end|>" in sample_prompt:
                    logger.info("✅ FIM 형식 주석 기반 모델 데이터 감지됨")
                    has_fim_format = True
                    break
        
        # 모드별 전처리 함수 정의
        def process_complete_data(examples):
            # FIM (Fill In the Middle) 형식 처리
            results = tokenizer(
                examples['content'],
                truncation=True,
                max_length=max_length,
                return_attention_mask=False
            )
            return results
            
        def process_prompt_data(examples):
            # 프롬프트 기반 생성 형식 처리
            results = tokenizer(
                examples['instruction'] + '\n' + examples['input'] + '\n' + examples['output'],
                truncation=True,
                max_length=max_length,
                return_attention_mask=False
            )
            return results
            
        def process_comment_data(examples):
            # 주석 기반 코드 생성 처리
            results = tokenizer(
                examples['comment'] + '\n' + examples['code'],
                truncation=True,
                max_length=max_length,
                return_attention_mask=False
            )
            return results
            
        def process_error_fix_data(examples):
            # 오류 수정 데이터 처리
            results = tokenizer(
                examples['error_description'] + '\n' + examples['buggy_code'] + '\n' + examples['fixed_code'],
                truncation=True,
                max_length=max_length,
                return_attention_mask=False
            )
            return results
        
        # 향상된 토큰화 함수 - train.py의 로직 통합
        def tokenize_function(examples):
            # 모드별, 데이터 형식별 처리를 통합
            texts = []
            
            # 샘플 수 결정 (다양한 형식 지원)
            sample_count = 0
            if "content" in examples and args.mode == 'complete':
                sample_count = len(examples["content"])
            elif "messages" in examples:
                sample_count = len(examples["messages"])
            elif "prompt" in examples:
                sample_count = len(examples["prompt"])
            elif "prefix_code" in examples:
                sample_count = len(examples["prefix_code"])
            elif "instruction" in examples:
                sample_count = len(examples["instruction"])
            elif "comment" in examples and "code" in examples:
                sample_count = len(examples["comment"])
            elif "error_description" in examples and "buggy_code" in examples:
                sample_count = len(examples["error_description"])
            else:
                logger.error("❌ 지원되지 않는 데이터 형식")
                return {"input_ids": []}
            
            # 각 샘플 처리
            for i in range(sample_count):
                # 데이터 형식 및 모드별 처리
                if data_format == "complete" or (args.mode == 'complete' and "content" in examples):
                    # 자동완성(complete) 모드
                    if "content" in examples:
                        text = examples["content"][i]
                        texts.append(text)
                
                elif data_format == "prompt_completion":
                    # 주석 키워드 기반 코드 생성 형식
                    prompt = examples["prompt"][i] if "prompt" in examples else ""
                    completion = examples["completion"][i] if "completion" in examples else ""
                    
                    # FIM 형식 처리 (3차 주석 기반 모델)
                    if args.mode == "comment" and ("<|fim begin|>" in prompt or "<|fim hole|>" in prompt or "<|fim end|>" in prompt):
                        # FIM 형식 그대로 유지하고 ChatML 형식으로 래핑
                        text = f"assistant\n{prompt}\n\nassistant\n{completion}\n"
                    else:
                        # 일반 프롬프트를 사용자 메시지로, 완성을 어시스턴트 메시지로 변환
                        text = f"assistant\n{prompt}\n\nassistant\n{completion}\n"
                    
                    texts.append(text)
                        
                elif data_format == "comment_to_code" or data_format == "comment_code" or args.mode == 'comment':
                    # 주석 기반 코드 생성 형식
                    if data_format == "comment_to_code":
                        prefix_code = examples["prefix_code"][i] if "prefix_code" in examples else ""
                        suffix_code = examples["suffix_code"][i] if "suffix_code" in examples else ""
                        comment = examples["comment"][i]
                        target_code = examples["target_code"][i] if "target_code" in examples else ""
                        
                        # 사용자 입력 구성 (주석과 코드 컨텍스트)
                        user_text = f"주석에 따라 적절한 코드를 생성해주세요.\n\n"
                        user_text += f"주석: {comment}\n\n"
                        user_text += "\n이전 코드:\n"
                        user_text += f"{prefix_code}\n"
                        user_text += "// 여기에 코드를 삽입해야 함 //\n"
                        user_text += f"{suffix_code}"
                        
                        # ChatML 형식으로 변환
                        text = f"assistant\n{user_text}\n\nassistant\n{target_code}\n"
                    else:
                        # 간단한 주석-코드 형식
                        comment = examples["comment"][i]
                        code = examples["code"][i]
                        text = f"assistant\n주석: {comment}\n\nassistant\n{code}\n"
                    
                    texts.append(text)
                    
                elif data_format == "error_fix" or args.mode == 'error_fix':
                    # 에러 수정 형식
                    if "error_description" in examples and "buggy_code" in examples and "fixed_code" in examples:
                        # 단순 에러 수정 형식
                        error_desc = examples["error_description"][i]
                        buggy_code = examples["buggy_code"][i]
                        fixed_code = examples["fixed_code"][i]
                        
                        text = f"assistant\n오류 설명: {error_desc}\n\nassistant\n{buggy_code}\n\nassistant\n{fixed_code}\n"
                        texts.append(text)
                    elif "error_context" in examples:
                        # 복잡한 에러 수정 형식
                        error_context = examples["error_context"][i]
                        fixed_code = examples["fixed_code_snippet"][i] if "fixed_code_snippet" in examples else ""
                        
                        # 에러 컨텍스트 정보 구성
                        error_log = error_context.get("error_log", "") if isinstance(error_context, dict) else ""
                        language = error_context.get("language", "") if isinstance(error_context, dict) else ""
                        
                        # buggy_code가 error_context 안에 있는 경우와 외부에 있는 경우 모두 처리
                        if isinstance(error_context, dict) and "buggy_code_snippet" in error_context:
                            buggy_code = error_context["buggy_code_snippet"]
                        elif "buggy_code_snippet" in examples:
                            buggy_code = examples["buggy_code_snippet"][i]
                        else:
                            buggy_code = ""
                            logger.warning("⚠️ buggy_code_snippet을 찾을 수 없습니다")
                        
                        # 사용자 입력 구성
                        user_text = f"다음 {language} 코드의 에러를 수정해주세요:\n\n에러 로그:\n{error_log}\n\n코드:\n{buggy_code}"
                        text = f"assistant\n{user_text}\n\nassistant\n{fixed_code}\n"
                        texts.append(text)
                else:
                    # 알 수 없는 형식은 원시 데이터 텍스트를 그대로 사용
                    logger.warning(f"⚠️ 지원되지 않는 데이터 형식: {data_format}, 모드: {args.mode}")
                    # 첫 번째 필드만 사용
                    if len(examples) > 0:
                        first_key = list(examples.keys())[0]
                        text = str(examples[first_key][i])
                        texts.append(text)
                    else:
                        texts.append("")
            
            # 토크나이징
            model_inputs = tokenizer(
                texts,
                truncation=True,
                max_length=max_length,
                return_attention_mask=False,
                padding=False
            )
            
            return model_inputs
            
        # 모드와 데이터 형식에 따른 처리 함수 선택
        logger.info(f"모드: {args.mode}, 감지된 데이터 형식: {data_format}")
        
        # 데이터셋 전처리 적용 - 향상된 토큰화 함수 사용
        # 데이터셋 토큰화 로깅 개선
        logger.info(f"✨ {args.mode} 데이터셋 토큰화 시작 (num_proc={config.get('preprocessing_num_workers', 4)})")
        
        # tqdm 설정 관련 코드
        # 라이브러리 버전 호환성 문제로 인해 tqdm_kwargs 매개변수를 사용하지 않음
        for bar in tqdm._instances:
            bar.close()
        
        # 토큰화 실행 - 기본 tqdm 설정 사용
        # 특히 AWS의 datasets 라이브러리에서는 tqdm_kwargs를 지원하지 않을 수 있음
        logger.info("토큰화 시작 - 토큰화 진행중 이후 로그를 확인하세요")
        tokenized_dataset = raw_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=raw_dataset['train'].column_names,
            num_proc=config.get('preprocessing_num_workers', 4),
            desc=f"{args.mode} 데이터셋 토큰화 중"
            # tqdm_kwargs 매개변수 제거 - 호환성 문제 해결
        )
        logger.info(f"✅ {args.mode} 데이터셋 토큰화 완료: 훈련={len(tokenized_dataset['train'])}개, 검증={len(tokenized_dataset['validation'])}개")
        
        # DataLoader 생성을 위한 준비
        train_dataset = tokenized_dataset['train']
        val_dataset = tokenized_dataset['validation']
        
        # DataCollatorForLanguageModeling 로드
        try:
            from transformers import DataCollatorForLanguageModeling
            logger.info("DataCollatorForLanguageModeling 로드 성공")
        except ImportError as e:
            logger.error(f"DataCollatorForLanguageModeling 로드 오류: {e}")
            raise ImportError("transformers 라이브러리에서 DataCollatorForLanguageModeling을 로드할 수 없습니다. 'pip install transformers'를 실행하여 최신 버전으로 업그레이드하세요.")
        
        # DataLoader 생성
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, 
            mlm=False
        )
        
        # collate_fn으로 커스텀 함수 사용하여 BatchEncoding 처리 보장
        def safe_collate_fn(features):
            try:
                # 원래 collator 호출
                batch = data_collator(features)
                
                # 디버깅: 첫 번째 배치만 로깅
                if logger.level <= logging.DEBUG and random.random() < 0.01:  # 1% 샘플링
                    logger.debug(f"배치 타입: {type(batch)}, 키: {list(batch.keys()) if hasattr(batch, 'keys') else 'N/A'}")
                    if hasattr(batch, 'input_ids'):
                        logger.debug(f"input_ids 형태: {batch.input_ids.shape if hasattr(batch.input_ids, 'shape') else 'N/A'}")
                
                return batch
            except Exception as e:
                logger.error(f"Collate 함수 오류: {e}")
                # 긴급 폴백: 원본 특성 그대로 반환
                return features
        
        # 훈련용 데이터 로더 생성 - 상세 정보 로깅
        batch_size = config.get('batch_size', 32)
        num_workers = config.get('dataloader_num_workers', 4)
        logger.info(f"학습 DataLoader 초기화: 배치 크기={batch_size}, 작업자={num_workers}")
        
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            collate_fn=safe_collate_fn,  # 안전한 collate 함수 사용
            shuffle=True,
            pin_memory=True,
            num_workers=num_workers
        )
        
        # 학습 로더 생성 후 추가 정보
        logger.info(f"학습 DataLoader 생성 완료: 배치 크기={batch_size}, 배치 개수={len(train_loader):,}")
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=config.get('batch_size', 32),
            collate_fn=safe_collate_fn,  # 안전한 collate 함수 사용
            shuffle=False,
            pin_memory=True,
            num_workers=config.get('dataloader_num_workers', 4)
        )
        
        # 로그에 데이터 정보 기록 - 상세 정보 추가
        total_steps = len(train_loader)
        logger.info(f"훈련 데이터 요약 - 학습용: {len(train_dataset):,} 샘플, 검증용: {len(val_dataset):,} 샘플")
        logger.info(f"학습 배치 정보 - 배치당 {batch_size}개 샘플 포함, 총 {total_steps:,}개 배치")
        logger.info(f"학습 예상 시간: 에포크당 최소 {(total_steps*1.5)/60:.1f}분 (배치당 1.5초 가정)")
        
        # 지속학습을 위한 태스크 처리
        # 모드별로 데이터를 별도의 태스크로 간주
        task_id = {'complete': 0, 'prompt': 1, 'comment': 2, 'error_fix': 3}.get(args.mode, 0)
        
        # 테스트 로더는 검증 데이터로 대체 (실제 환경에서는 별도의 테스트 데이터 사용 권장)
        test_loaders = {task_id: val_loader}
        
        # AWS 스팟 인스턴스 중단 감시 시작
        logger.info("AWS 스팟 인스턴스 중단 감시 시작")
        learner.spot_handler.start_monitoring(
            model=learner.model,
            optimizer=learner.optimizer,
            scheduler=learner.scheduler,
            scaler=learner.scaler,
            trainer=learner
        )
        
        # 태스크 훈련 실행
        logger.info(f"태스크 {task_id} ({args.mode} 모드) 훈련 시작")
        learner.train_task(task_id, train_loader, val_loader, test_loaders)
        
        # AWS 스팟 인스턴스 중단 감시 종료
        learner.spot_handler.stop_monitoring()
        
    except Exception as e:
        logger.error(f"데이터 로드 또는 훈련 중 오류 발생: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        sys.exit(1)

    # 최종 메트릭 계산 및 결과 저장
    final_metrics = learner.metrics.compute_metrics()
    results_path = os.path.join(config.get('log_dir', 'runs/continual_learning'), 'final_results.json')
    os.makedirs(os.path.dirname(results_path), exist_ok=True)

    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump({
            'metrics': final_metrics,
            'best_accuracies': learner.best_accuracies,
            'config': config
        }, f, indent=2)

    logger.info(f"최종 결과가 저장되었습니다: {results_path}")
    logger.info(f"최종 메트릭: {final_metrics}")

if __name__ == '__main__':
    main()

0개의 댓글