C.W.K.
Stream
Lesson 02 of 05 · published

Gradient Checkpointing: jax.checkpoint / jax.remat

~10 min · advanced, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

긴 sequence / 깊은 model 학습의 한 가지 큰 적 — backward 에 필요한 forward activation 들이 메모리를 다 잡아먹음. 해결: gradient checkpointing. trade memory for compute.

표준 backprop 의 메모리 패턴

forward:  x → h1 → h2 → h3 → h4 → loss
              [저장]  [저장]  [저장]  [저장]
backward: 모든 h 사용해서 grad 계산

L-layer 모델 — 모든 layer 의 activation 메모리 보존. seq=4096 transformer layer 24 개 — 메모리 폭발.

checkpoint — 일부 또는 전부 안 저장, backward 에서 재계산

import jax

def expensive_layer(params, x):
    # 큰 activation 만드는 layer
    h = jnp.tanh(x @ params["W1"])
    h = h @ params["W2"]
    return h

# checkpoint 적용 — forward 에선 activation 안 저장, backward 에서 재계산
checkpointed = jax.checkpoint(expensive_layer)

def model(params, x):
    for layer_params in params:
        x = checkpointed(layer_params, x)
    return x

# 학습 — 메모리 절감, compute 추가
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

전형적 효과 — Transformer 학습에서 — 메모리 50% 절감 (=2x 큰 batch 가능), compute 33% 증가 (forward 1 번 추가).

granularity — 어디까지 checkpoint?

# 전체 model — 너무 거침
checkpointed_model = jax.checkpoint(model)

# 각 layer — 표준
def model(params, x):
    for layer in params:
        x = jax.checkpoint(layer_fn)(layer, x)
    return x

# 매 N layer 마다 — 더 미세 조정
N = 4
def model(params, x):
    for i in range(0, len(params), N):
        chunk = params[i:i+N]
        x = jax.checkpoint(lambda c, x: chunk_fn(c, x))(chunk, x)
    return x

최적 granularity — model 마다 다름. 큰 layer (attention 등) 는 자체 checkpoint, 작은 op 들은 묶기.

policy 를 직접 지정

import jax.checkpoint_policies as ckpt_policies

# 중요한 op (matmul 같은 거) 만 저장, 나머지는 재계산
checkpointed = jax.checkpoint(
    expensive_layer,
    policy=ckpt_policies.checkpoint_dots_with_no_batch_dims,
)

# 또는 직접
checkpointed = jax.checkpoint(
    expensive_layer,
    policy=ckpt_policies.dots_saveable,
)

실전: Transformer 학습

def transformer_block(params, x, mask):
    '''attention + MLP — 큰 activation'''
    h = layer_norm(x, params["ln1"])
    h = attention(params["attn"], h, mask)
    x = x + h

    h = layer_norm(x, params["ln2"])
    h = mlp(params["mlp"], h)
    x = x + h
    return x

# 각 block 마다 checkpoint
def model(params, x, mask):
    for block_params in params["blocks"]:
        x = jax.checkpoint(transformer_block)(block_params, x, mask)
    x = layer_norm(x, params["final_ln"])
    return x @ params["head"]

같은 GPU 에서 — 4 배 긴 sequence 학습 가능. trade-off: 학습 속도 ~ 30% 느려짐.

⚖️ 메모리 vs compute

학습 메모리는 — 거의 전부 forward activation. jax.checkpoint 가 그 비용을 시간으로 환전. 큰 모델 / 긴 sequence 일수록 효과 큼. 실전 — 모델이 OOM 나면 — 가장 먼저 시도하는 최적화. ZeRO 같은 sharding 보다 진입 장벽 낮고 효과 즉각적.

JAX 의 jax.rematjax.checkpoint 의 alias. 옛 이름 — 최근 코드는 checkpoint 가 표준.

Code

import jax
import jax.numpy as jnp

# Without checkpointing: stores all intermediate activations
def forward(params, x):
    for layer in params:
        x = jax.nn.relu(x @ layer['w'] + layer['b'])
    return x

# With checkpointing: recomputes activations during backward pass
@jax.checkpoint  # or equivalently, @jax.remat
def forward_checkpointed(params, x):
    for layer in params:
        x = jax.nn.relu(x @ layer['w'] + layer['b'])
    return x

# Same output, same gradients, but uses much less memory
grads = jax.grad(lambda p: jnp.sum(forward_checkpointed(p, x)))(params)
from flax import nnx
import jax

class TransformerBlock(nnx.Module):
    def __init__(self, d_model, num_heads, d_ff, rngs):
        self.attention = nnx.MultiHeadAttention(
            num_heads=num_heads, in_features=d_model,
            qkv_features=d_model, out_features=d_model, rngs=rngs)
        self.ff1 = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.ff2 = nnx.Linear(d_ff, d_model, rngs=rngs)
        self.ln1 = nnx.LayerNorm(d_model, rngs=rngs)
        self.ln2 = nnx.LayerNorm(d_model, rngs=rngs)

    @nnx.jit
    def __call__(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.ff2(nnx.gelu(self.ff1(self.ln2(x))))
        return x

# Checkpoint individual blocks
class Transformer(nnx.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, rngs):
        self.blocks = [
            TransformerBlock(d_model, num_heads, d_ff, rngs)
            for _ in range(num_layers)
        ]

    def __call__(self, x):
        for block in self.blocks:
            # Remat each block: activations are recomputed in backward
            x = jax.checkpoint(lambda b, x: b(x), block, x)
        return x
# You can also use remat with a policy for fine-grained control
from jax.ad_checkpoint import checkpoint_policies

# Only save certain operations (e.g., dots but not norms)
policy = checkpoint_policies.save_only_these_names('dot_general')

@jax.remat(policy=policy)
def block_with_policy(params, x):
    # Only the results of matrix multiplications are saved;
    # everything else is recomputed
    return transformer_block(params, x)

External links

Exercise

4-layer MLP 를 합성 data 에 학습. 각 layer 의 forward 를 jax.checkpoint 로 wrap. before vs after 의 peak memory 측정. trade — recompute time for memory — 가 long-context Transformer 가 single GPU 에 들어가는 이유.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.