포스트

[AI 말평 대회] 참여기 #11: 3주차(3) - GRPO 학습 구현

[AI 말평 대회] 참여기 #11: 3주차(3) - GRPO 학습 구현

11. GRPO 학습 구현

저번에 개념과 보상 아이디어를 정리했다면, 이번 편은 실제 구현이다.

팀원분의 조언에 따라 콜백(CallBack)을 통해 저장 시점마다 자동 검증 추론 → JSON 저장 → 한 번에 스코어링 → 최고 성능 체크포인트 선택 파이프라인을 구성했다.


1) 데이터/프롬프트 파이프라인

  • CustomDataset은 입력 JSON에서 questionanswer를 읽고, 지시문(instruct)과 few-shot 예시를 합쳐 인퍼런스용 입력을 구성한다.
  • few-shot은 유형별 샘플링(Category_FewShotGenerater) 을 사용해 선택형/교정형을 구분해 넣었다.
1
2
3
4
5
6
7
8
9
# 데이터셋/지시문 준비 (요약)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
few_shot = Category_FewShotGenerater(args.few_shot_data_path, num_few_shot=args.num_few_shot_data)

with open(args.instruct_path, 'r', encoding='utf-8') as f:
    instruct = json.load(f)

train_dataset = CustomDataset(args.train_data_path, tokenizer, instruct, few_shot_generater=few_shot)
eval_dataset  = CustomDataset(args.eval_data_path,  tokenizer, instruct, few_shot_generater=few_shot)

2) 보상 함수(Reward) 구성

대회 지표를 그대로 보상으로 쓴다. Format(“…이/가 옳다.”), EM, ROUGE‑1, BERTScore 를 조합한다.

1
2
3
4
5
6
7
def multi_reward(prompts, completions, references=None, **kwargs):
    r_bert  = BERTScore_reward_fn(prompts, completions, references)
    r_em    = EM_reward_fn(prompts, completions, references)
    r_form  = format_reward_fn(prompts, completions, references)
    r_rouge = ROUGE_1_reward_fn(prompts, completions, references)
    # 조합: (BERT + EM + (Format + 3*ROUGE)/2) / 3
    return [(b + e + (f + 3*r)/2) / 3 for b, e, f, r in zip(r_bert, r_em, r_form, r_rouge)]

포맷 보상은 정답 문장 형식을 정규식으로 체크하고, EM은 정답문장 일치(여러 허용 정답 # 분리)로 0/1 스코어,
ROUGE‑1/BERTScore는 이유(reason) 부분의 유사도를 평가한다.


3) GRPO 설정과 콜백 등록

학습은 TRL의 GRPOTrainer 를 사용하고, 저장 타이밍마다 2개의 콜백이 동작한다.

  • LoraSaveCallback: LoRA 어댑터를 output_dir/epoch_{E}에 저장
  • InferenceCallback: 방금 저장된 가중치를 즉시 로드해 검증셋에 추론inference_epoch{E}.json 저장
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
grpo_config = GRPOConfig(
    num_generations=args.num_generations,
    output_dir=args.output_dir,
    num_train_epochs=args.num_train_epochs,
    per_device_train_batch_size=args.per_device_train_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    logging_steps=args.logging_steps,
    save_strategy=args.save_strategy,    # "epoch" 권장
    eval_strategy=args.eval_strategy,    # "no" 또는 "epoch"
    load_best_model_at_end=False,        # 모델 선택은 별도 스크립트로
    remove_unused_columns=False,
    report_to="wandb"
)

trainer = GRPOTrainer(
    model=base_model,
    args=grpo_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    reward_funcs=[multi_reward],
    callbacks=[LoraSaveCallback(), inference_cb]  # 아래 InferenceCallback 참고
)

4) 메모리 관리

GRPO는 한 프롬프트에서 여러 응답을 생성하기 때문에, 메모리 사용량이 num_generations(G)선형으로 늘어난다.

기호

  • B = per_device_train_batch_size
  • G = num_generations
  • A = gradient_accumulation_steps
  • D = GPU 개수(world size)

효과적 배치 크기

  • 한 옵티마이저 스텝 총 응답 수: B × G × A × D
  • 한 옵티마이저 스텝 프롬프트 수: (B ÷ G) × A × D

⚠️ B는 반드시 G의 배수여야 한다.

정해진 GPU VRAM 제한(24GB) 내에서 학습을 돌리느라 OOM을 정말 많이 본것 같다.

그 때문에 양자화와 LoRA등 정말 다양한 메모리 절약 방식을 적용하면서 많은걸 배웠다.

학습을 돌리면서 이 부분을 특히 주의해서 파라미터 구성을 했던 경험이 생각나 공유해본다.


🔧 양자화 + LoRA 적용 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# --- 양자화 설정 ---
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="bfloat16",   # bf16 연산
    bnb_4bit_quant_type="nf4",           # NormalFloat4
    bnb_4bit_use_double_quant=True
)

# --- 모델 로드 (4bit 양자화 적용) ---
base_model = AutoModelForCausalLM.from_pretrained(
    args.model_name,
    quantization_config=bnb_config,
    device_map="auto"
)

# --- LoRA 설정 ---
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# --- LoRA 어댑터 부착 ---
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

위 코드에서:

  • 양자화: 4bit(nf4)로 메모리 절약, 연산은 bfloat16
  • LoRA: q/k/v/o projection 층만 학습 → 파라미터/VRAM 최소화

이런 설정으로 24GB VRAM에서도 K-intelligence/Midm-2.0-Base-Instruct같은 큰 모델까지 학습이 가능했다.


5) InferenceCallback: 저장 직후, 검증셋 자동 추론

학습이 저장(save) 될 때마다 다음을 수행한다.

1) 가장 최근 LoRA 어댑터 폴더(예: epoch_3/) 경로를 구성
2) 해당 가중치를 붙여 검증셋에 generate
3) JSON(..._epoch3.json) 파일로 결과를 기록

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class InferenceCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        epoch = int(state.epoch)
        lora_path = os.path.join(args.output_dir, f"epoch_{epoch}")
        run_inference(
            input_path=args.eval_data_path,
            output_path=args.inference_output_path.replace(".json", f"_epoch{epoch}.json"),
            model_id=args.model_name,
            tokenizer_id=args.model_name,
            lora_weights_path=lora_path,  # 최신 LoRA
            prompt=instruct["prompt"],
            correction_prompt=instruct["correction_prompt"],
            selection_prompt=instruct["selection_prompt"],
            quant=args.quant_type if args.use_quant else None,
            few_shot_generater=few_shot
        )

run_inference양자화 옵션/LoRA 부착/종료 토큰 설정 후, few‑shot 프롬프트로 generate를 호출해 각 문제의 답변과 정규화된 답변(normalized) 을 JSON으로 쓴다.

1
2
3
4
5
6
7
8
9
10
11
12
13
def run_inference(...):
    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_cfg, ...)
    if lora_weights_path:
        model = PeftModel.from_pretrained(model, lora_weights_path)
    model.eval()

    dataset = CustomDataset(input_path, tokenizer, instruct=..., few_shot_generater=few_shot)
    for i in range(len(dataset)):
        outputs = model.generate(...)
        text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
        text = strip_prefixes_and_normalize(text)
        result[i]["output"] = {"answer": text, "normalized_answer": normalize_quotes(text)}
    json.dump(result, open(output_path, "w"), ensure_ascii=False, indent=4)

6) 한 번에 평가하고 최고 체크포인트 선택

에폭마다 생긴 inference_epoch{E}.json 들을 한 디렉토리에 모아 아래 스크립트로 평가한다.
평가는 ROUGE‑1, BERTScore, EM 을 평균해 Final Mean을 만들고, 가장 높은 파일이 곧 Best Checkpoint다.

1
2
python evaluate_json.py   --input_dir outputs/infer_results/ \  # *.json들이 들어있는 폴더
  --label_path data/labels_eval.json

출력 예시:

1
2
3
4
5
6
7
8
9
File: inference_epoch3.json
   - ROUGE-1        : 0.4132
   - BERTScore      : 0.6871
   - EM             : 0.5528
   - Final Mean     : 0.5510

Best Final Mean:
   File       : inference_epoch5.json
   Final Mean : 0.5742

7) 마무리

결론적으로 현재 학습 파이프라인에서 콜백으로 학습‑저장‑검증‑선택 루프를 자동화하여, 사람 손을 거의 쓰지 않고 가장 좋은 체크포인트를 고를 수 있었다.

다음 편은 학습을 돌려본 결과에 대해 다뤄볼 예정이다.

이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.