C.W.K.
Stream
Lesson 08 of 08 · published

SFTTrainer, Chat Template, Merging

~26 min · training, trl, sft

Level 0스카우트
0 XP0/50 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete

SFTTrainer 가 instruction tuning 의 right Trainer

trl.SFTTrainer 가 chat-format 데이터 위해 Trainer 를 편리하게 wrap:

  • messages 컬럼 데이터셋 직접 받음.
  • 토크나이저의 chat template 자동 적용.
  • response_template 주면 DataCollatorForCompletionOnlyLM 디폴트 셋업.
  • 인자 두 개 추가하면 PEFT (LoRA / QLoRA) 와 plug.

데이터셋 포맷

둘 중 하나 동작:

  • Conversational: 각 행이 messages: [{"role":"user","content":...}, {"role":"assistant","content":...}].
  • Single-turn text: 각 행이 이미 포맷된 prompt + response 담은 text 컬럼.

Conversational 이 모던 디폴트. Trainer 가 모델 chat template 적용; string 손으로 포맷 X.

Code

SFTTrainer + LoRA + 챗 데이터셋·python
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
import torch

ds = load_dataset("HuggingFaceH4/no_robots", split="train")  # conversational

base = "Qwen/Qwen2.5-1.5B-Instruct"
tok = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16)

lora_cfg = LoraConfig(
    r=16, lora_alpha=32,
    target_modules="all-linear",
    bias="none", task_type="CAUSAL_LM",
)

cfg = SFTConfig(
    output_dir="./sft-out",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=20,
    save_steps=500,
    max_seq_length=2048,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tok,
    args=cfg,
    train_dataset=ds,
    peft_config=lora_cfg,
)
trainer.train()
trainer.save_model("./sft-adapter")
merged 모델 + tokenizer + adapter 를 Hub 에·python
from peft import AutoPeftModelForCausalLM

m = AutoPeftModelForCausalLM.from_pretrained("./sft-adapter")
merged = m.merge_and_unload()
merged.save_pretrained("./qwen-sft-merged")
tok.save_pretrained("./qwen-sft-merged")

# Push (private)
merged.push_to_hub("yourname/qwen-sft", private=True)
tok.push_to_hub("yourname/qwen-sft", private=True)

External links

Exercise

작은 챗 데이터셋 (HuggingFaceH4/no_robots 또는 너 JSONL). 1B-3B 모델을 SFTTrainer + LoRA 로 SFT. Adapter 저장. 인퍼런스 테스트: unmerged 모델 (base + adapter), merged 모델. 출력 비슷해야 함.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.