Trying LoRA Training on Kaggle
File Description
Installing Libraries and Packages
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"
import os
# Configure the deep learning framework backend and memory management.
os.environ["KERAS_BACKEND"] = "jax" # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
🖍️ The first setting is like choosing which car to drive.
🖍️ The second setting is like deciding how much to fill the car's fuel tank.
import keras
import keras_nlp
import json
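Since KERAS_BACKEND must be set before Keras is first imported, it is exported above; once Keras is imported you can confirm which backend is active (an optional sanity check, not part of the original notebook):

# Should print "jax" because the environment variable was set before the import.
print(keras.backend.backend())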
data = []
with open('/content/databricks-dolly-15k.jsonl') as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]
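Each retained example becomes one prompt-and-response string; printing the first entry is a quick way to sanity-check the template (the bracketed text below stands in for actual Dolly content):

print(data[0])
# The output has the form:
# Instruction:
# <instruction text from the dataset>
#
# Response:
# <response text from the dataset>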
Kaggle API
!mkdir -p ~/.kaggle
!echo '{"username":"taixi1992","key":"사용자"}' > ~/.kaggle/kaggle.json
!chmod 600 ~/.kaggle/kaggle.json
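As an alternative to writing ~/.kaggle/kaggle.json, the Kaggle tooling also reads credentials from environment variables; a minimal sketch with placeholder values:

import os
# Same credentials as kaggle.json; both values below are placeholders.
os.environ["KAGGLE_USERNAME"] = "your_kaggle_username"
os.environ["KAGGLE_KEY"] = "your_kaggle_api_key"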
Downloading and Loading the Model
import kagglehub

# Download the Gemma model files from Kaggle (returns the local path).
model = kagglehub.model_download("keras/gemma/keras/gemma_2b_en/2")
# Load the pretrained Gemma 2B causal LM from the Keras preset.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()
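Before LoRA is enabled, every parameter is trainable; count_params gives the total model size (an optional check, the exact number depends on the loaded preset):

# Total parameter count of the base model; all of these are trainable before LoRA.
print(f"Base model parameters: {gemma_lm.count_params():,}")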
Before Fine-tuning
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
LoRA Fine-tuning
gemma_lm.backbone.enable_lora(rank=8)
gemma_lm.summary()
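With LoRA enabled, the original weights are frozen and only the low-rank adapter matrices train; summary() above reports this, and you can also count the trainable weights directly (an illustrative check):

import math
# Only the LoRA adapter matrices should appear among the trainable weights now.
trainable = sum(math.prod(w.shape) for w in gemma_lm.trainable_weights)
print(f"Trainable parameters with LoRA rank 8: {trainable:,}")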
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
LoRA Fine-tuning Results
The Rank(4) outputs below come from an otherwise identical run with enable_lora(rank=4); the code above trains with rank=8.
# Rank(4)
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

# Rank(8)
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

# Rank(4)
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

# Rank(8)
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Switching from rank 4 to rank 8 seems to produce better results.
💡 A higher LoRA rank gives the model more capacity for fine-grained adaptation and tends to improve quality, especially on complex tasks and datasets. However, it also trains more parameters, so the balance between resources and performance needs to be considered.
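To put numbers on that trade-off: for a frozen d_in x d_out weight matrix, a rank-r LoRA adapter adds r * (d_in + d_out) trainable parameters. A small illustrative calculation follows; the 2048 x 2048 projection shape is an assumption for a Gemma 2B attention matrix, not taken from the code above.

def lora_extra_params(d_in, d_out, rank):
    # LoRA adds two low-rank factors: A with shape (d_in, rank) and B with shape (rank, d_out).
    return rank * (d_in + d_out)

# Hypothetical 2048 x 2048 attention projection.
for rank in (4, 8):
    print(rank, lora_extra_params(2048, 2048, rank))
# Prints 4 16384 then 8 32768; rank 8 doubles the adapter size per adapted matrix.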