긴 sequence / 깊은 model 학습의 한 가지 큰 적 — backward 에 필요한 forward activation 들이 메모리를 다 잡아먹음. 해결: gradient checkpointing. trade memory for compute.
표준 backprop 의 메모리 패턴
forward: x → h1 → h2 → h3 → h4 → loss
[저장] [저장] [저장] [저장]
backward: 모든 h 사용해서 grad 계산
L-layer 모델 — 모든 layer 의 activation 메모리 보존. seq=4096 transformer layer 24 개 — 메모리 폭발.
checkpoint — 일부 또는 전부 안 저장, backward 에서 재계산
import jax
def expensive_layer(params, x):
# 큰 activation 만드는 layer
h = jnp.tanh(x @ params["W1"])
h = h @ params["W2"]
return h
# checkpoint 적용 — forward 에선 activation 안 저장, backward 에서 재계산
checkpointed = jax.checkpoint(expensive_layer)
def model(params, x):
for layer_params in params:
x = checkpointed(layer_params, x)
return x
# 학습 — 메모리 절감, compute 추가
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
전형적 효과 — Transformer 학습에서 — 메모리 50% 절감 (=2x 큰 batch 가능), compute 33% 증가 (forward 1 번 추가).
granularity — 어디까지 checkpoint?
# 전체 model — 너무 거침
checkpointed_model = jax.checkpoint(model)
# 각 layer — 표준
def model(params, x):
for layer in params:
x = jax.checkpoint(layer_fn)(layer, x)
return x
# 매 N layer 마다 — 더 미세 조정
N = 4
def model(params, x):
for i in range(0, len(params), N):
chunk = params[i:i+N]
x = jax.checkpoint(lambda c, x: chunk_fn(c, x))(chunk, x)
return x
최적 granularity — model 마다 다름. 큰 layer (attention 등) 는 자체 checkpoint, 작은 op 들은 묶기.
policy 를 직접 지정
import jax.checkpoint_policies as ckpt_policies
# 중요한 op (matmul 같은 거) 만 저장, 나머지는 재계산
checkpointed = jax.checkpoint(
expensive_layer,
policy=ckpt_policies.checkpoint_dots_with_no_batch_dims,
)
# 또는 직접
checkpointed = jax.checkpoint(
expensive_layer,
policy=ckpt_policies.dots_saveable,
)
실전: Transformer 학습
def transformer_block(params, x, mask):
'''attention + MLP — 큰 activation'''
h = layer_norm(x, params["ln1"])
h = attention(params["attn"], h, mask)
x = x + h
h = layer_norm(x, params["ln2"])
h = mlp(params["mlp"], h)
x = x + h
return x
# 각 block 마다 checkpoint
def model(params, x, mask):
for block_params in params["blocks"]:
x = jax.checkpoint(transformer_block)(block_params, x, mask)
x = layer_norm(x, params["final_ln"])
return x @ params["head"]
같은 GPU 에서 — 4 배 긴 sequence 학습 가능. trade-off: 학습 속도 ~ 30% 느려짐.
⚖️ 메모리 vs compute
학습 메모리는 — 거의 전부 forward activation. jax.checkpoint 가 그 비용을 시간으로 환전. 큰 모델 / 긴 sequence 일수록 효과 큼. 실전 — 모델이 OOM 나면 — 가장 먼저 시도하는 최적화. ZeRO 같은 sharding 보다 진입 장벽 낮고 효과 즉각적.
JAX 의 jax.remat 는 jax.checkpoint 의 alias. 옛 이름 — 최근 코드는 checkpoint 가 표준.