C.W.K.
Stream
Lesson 05 of 06 · published

Expert load balancing 과 expert collapse

~11 min · moe, training, balancing

Level 0Scout
0 XP0/41 lessons0/12 achievements
0/100 XP to next level100 XP to go0% complete

이해해야 할 실패 모드

Expert collapse 가 MoE training 의 가장 큰 단일 실패 모드. Routing 이 unconstrained 면, 네트워크가 빠르게 대부분 토큰을 작은 expert subset 으로 보내고, 그 expert 들이 가장 빨리 개선되고, router 가 훨씬 더 많은 토큰을 그들에게 보내고, ... 루프 보이지. 끝 상태: 몇 expert 가 다 처리, 나머지는 메모리 차지하는 dead weight, benefit 없음.

왜 일어나

Router 가 expert 와 함께 학습. Training 시작에 가장 좋은 expert 가 더 많은 토큰 받고, 더 좋아지고, 더 많은 토큰. Load 분산 명시 압력 없으면 optimizer 가 underused expert 살릴 이유 없음. Catch 못 하면 multi-million-dollar training run 망칠 수 있는 self-reinforcing 실패.

해결책

  • Auxiliary load-balancing loss. Batch 가로질러 대략 동등한 expert 활용 장려하는 penalty 항. Switch Transformer 이후 표준; Mixtral 과 대부분 초기 MoE 사용.
  • Auxiliary-loss-free balancing. DeepSeek-V3 가 도입한 영리한 대안: router logits 에 학습 가능한 per-expert bias 항 추가, training 중 관찰된 expert 활용도 기반 조정. 추가 loss 항 없음, main objective 와 간섭 없음.
  • Expert capacity limit. Single expert 가 batch 당 처리 가능한 토큰 수 cap. 초과 토큰 두 번째 선택 expert 로 routing 또는 drop. Production MoE training 흔함.
  • Noise injection. Training 중 router score 에 noise 추가해서 너무 일찍 fully deterministic routing pattern 형성 방지.

왜 이게 inference 에서도 중요해

Training 후에도 expert 활용 perfectly uniform 안 됨 — 워크로드에 따라 일부 expert 가 진짜 다른 expert 보다 더 발화. Production 에서 GPU 가로질러 load skew 로 나타남: 인기 expert 가진 GPU 가 두들겨 맞고, 다른 GPU 들 반쯤 idle. Production MoE serving stack (vLLM, TensorRT-LLM) 이 동적으로 처리해야 해.

Balancing 전략 위한 model card 읽기

Model card 가 "auxiliary loss" 또는 balancing coefficient 명시하면 표준 길. "auxiliary-loss-free" 또는 "bias-based balancing" 언급하면 DeepSeek-style 디자인 읽고 있는 거. Balancing 전혀 언급 없으면 모델이 dense 거나 누군가 critical detail 잊은 거.

Code

Auxiliary load-balancing loss (Switch Transformer style)·python
import torch

def load_balancing_loss(router_logits, expert_indices, num_experts):
    # router_logits: (B*T, num_experts)
    # expert_indices: (B*T, k) — top-k chosen experts
    fraction_per_expert = torch.zeros(num_experts, device=router_logits.device)
    for k in range(expert_indices.shape[-1]):
        fraction_per_expert.scatter_add_(
            0,
            expert_indices[..., k].view(-1),
            torch.ones_like(expert_indices[..., k].view(-1), dtype=torch.float),
        )
    fraction_per_expert /= expert_indices.shape[0]

    avg_router_prob = torch.softmax(router_logits, dim=-1).mean(dim=0)
    return num_experts * (fraction_per_expert * avg_router_prob).sum()
DeepSeek-V3 style bias 조정 (pseudocode)·python
# Each step: nudge expert_bias toward balancing observed load.
def adjust_expert_bias(expert_bias, observed_load, target_load, lr=1e-3):
    # observed_load: tokens routed to each expert this step
    # target_load:   ideal balanced load
    delta = (target_load - observed_load) * lr
    expert_bias += delta
    return expert_bias
# No backprop, no aux loss term, just a slow-feedback correction.

External links

Exercise

DeepSeek-V3 technical report 의 auxiliary-loss-free balancing 섹션 읽어. 자기 말로 학습된 bias 항 추가가 explicit loss 없이 어떻게 작동하는지 설명. 암묵적 gradient 가 뭐고 어디서 와? 2024 의 더 elegant 한 MoE 디자인 아이디어 중 하나, 완전히 이해할 가치 있어.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.