C.W.K.
Stream
Lesson 04 of 05 · published

JAX AI Stack 과 Pallas Custom Kernel

~8 min · ecosystem, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

2024 년 Google 이 JAX AI Stack 발표 — JAX 기반 ML 도구의 통합 패키지. Flax NNX (model), Optax (optimizer), Orbax (checkpoint), Grain (data loader), 기타. 한 번에 설치:

pip install jax-ai-stack

새 프로젝트 시작 시 — 이 한 줄로 verified-compatible version 의 라이브러리 묶음 설치. 버전 충돌 안 남.

Pallas — JAX 의 custom kernel 언어

표준 JAX op 으로 표현 안 되는 — 매우 특수한 hardware 친화 kernel. JAX 의 PyTorch Triton 비슷.

from jax.experimental import pallas as pl

@pl.pallas_call(out_shape=jax.ShapeDtypeStruct((1024,), jnp.float32))
def add_kernel(x_ref, y_ref, out_ref):
    '''단순 vector add — Pallas 식으로'''
    x = x_ref[...]
    y = y_ref[...]
    out_ref[...] = x + y

# 사용 — 표준 JAX 처럼
x = jnp.ones(1024)
y = jnp.ones(1024)
result = add_kernel(x, y)

위 — single thread block 의 vector add. 실전 — 더 복잡한 fused kernel, 특히 attention 변형:

# Flash Attention — Pallas 로 구현 (의사 코드)
@pl.pallas_call(...)
def flash_attention_kernel(q_ref, k_ref, v_ref, out_ref):
    # block-wise loading
    # softmax 의 online algorithm
    # weighted sum 를 in-place 누적
    ...

왜 Pallas 가 필요한가:

  • memory hierarchy 활용 — HBM ↔ SRAM 의 명시적 제어. attention 의 IO 비용이 메모리 hierarchy 에 따라 결정.
  • fusion 의 한계 돌파 — XLA 의 자동 fusion 이 안 잡는 패턴.
  • 새 hardware 활용 — TPU 의 sparse core, GPU 의 tensor core 의 직접 활용.

실전 — Pallas 가 빛나는 곳

분야Pallas 의 역할
Flash Attention긴 sequence 의 attention 메모리 감소
Sparse layerssparse matrix multiplication 의 효율 kernel
Quantizationint8 / fp8 행렬 연산
MoE (Mixture of Experts)scatter/gather 의 효율 패턴
LSH / approximate ops해시 기반 sampling, near-neighbor

대부분의 사용자 — Pallas 가 필요 없어. 표준 JAX op 으로 충분. Pallas 가 결정적인 경우 — 매우 큰 모델, 특수 hardware 활용, 새 algorithm research.

JAX 진영의 다른 보조 라이브러리

  • chex — 단위 test, 디버깅, dataclass
  • clu (Common Loop Utils) — metric 누적, summary writer
  • einshape / einops — tensor shape 조작
  • xpilot — distributed coordinator (Google 내부 → 외부 release)
  • distrax — probability distribution (DeepMind)

🛠 production stack 추천

(1) framework: Flax NNX 또는 Equinox. (2) optimizer: Optax. (3) checkpoint: Orbax. (4) data: Grain (또는 PyTorch DataLoader). (5) logging: wandb. (6) test: chex. (7) 기타: jax-ai-stack 의 의존성. Pallas 는 — 표준 op 이 막힐 때만. 이 stack 으로 — Llama-class 모델 학습 가능.

Pallas 의 위치 — 비유하자면 JAX 의 inline assembly. 99% 의 코드는 안 쓰지만 — 그 1% 가 결정적일 수 있어. 알아둘 만한 정도.

Code

# Install the whole stack at once
# pip install jax-ai-stack

# This gives you:
# - jax            (core)
# - flax           (neural networks)
# - optax          (optimizers)
# - orbax          (checkpointing)
# - ml_dtypes      (bfloat16, etc.)
# - grain          (data loading)
# - chex           (testing utilities)

import jax
from flax import nnx
import optax
import orbax.checkpoint as ocp

# Version check (early 2026):
# jax ~0.9.x, flax ~0.12.x, optax ~0.2.x
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

# A simple Pallas kernel: vector addition
def add_kernel(x_ref, y_ref, o_ref):
    """Pallas kernel: reads from x and y, writes to o."""
    o_ref[...] = x_ref[...] + y_ref[...]

# Launch the kernel
def pallas_add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.ones(1024)
y = jnp.ones(1024) * 2
result = pallas_add(x, y)
print(result[:5])  # [3. 3. 3. 3. 3.]

External links

Exercise

Pallas overview 읽기. full kernel 안 짜도 됨 — 추상화만 이해 (JAX 의 Triton-like). DeepMind/Google research 코드의 Pallas 사용 kernel 1 개 찾고 — 무엇을 하는지 3 줄 메모. 거기 있다는 걸 아는 정도면 — 필요할 때까지 충분.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.