class PreTrainedModel
[PreTrainedModel
]은 모델의 컨피그를 저장하고, 모델을 로딩하고 다운로드하고 저장하는 메서드를 다루며, 인풋 임베딩 사이즈 재정의, 셀프어탠션 헤드 프루닝을 진행합니다.
def init
1.config를 설정해준다.
self.config
_attn_implementation을 확인한다.
flash_attn_2 > sdpa
config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"] "eager" = manual attention implementation
"sdpa" = cls._supports_flash_attn_2로 세팅해주고 cls._check_and_enable_flash_attn_2 ->
"flash_attention_2" = cls._supports_sdpa
cls._check_and_enable_sdpa ->
flash attention은 bf16, fp16에서만 지원한다. 그러면 quantization 모델은?
def post_init(self):