JAX 가 가장 빛나는 또 한 영역 — 미분 가능한 물리 시뮬레이션. 실제 물리 법칙으로 도는 시뮬레이터를 통해 gradient 가 흘러서 — system identification, control, 강화학습 의 새 패러다임이 가능.
Brax — Google 의 JAX 물리 엔진
pip install brax
from brax import envs
from brax.io import html
# pendulum 환경
env = envs.create("pendulum")
# rollout
state = env.reset(jax.random.PRNGKey(0))
for _ in range(200):
action = jnp.array([0.0]) # zero action
state = env.step(state, action)
Brax 는 — JAX 위에서 도는 rigid-body physics. MuJoCo 같은 환경 다수 제공. GPU 위에서 — 1000 개 환경을 vmap 으로 동시 시뮬레이션 가능. 강화학습 학습 속도 1000x 까지.
# 1000 개 environment, 동시 rollout
envs_batch = jax.vmap(env.step, in_axes=(0, 0))
states = jax.vmap(env.reset)(jax.random.split(key, 1000))
for _ in range(200):
actions = policy(states) # (1000, action_dim)
states = envs_batch(states, actions)
MJX — MuJoCo on JAX
2024 년에 DeepMind 가 출시. MuJoCo 의 JAX 포팅. Brax 보다 더 정확한 contact, friction. RL/robotics 표준 환경 다수.
import mujoco
from mujoco import mjx
model = mujoco.MjModel.from_xml_path("humanoid.xml")
mjx_model = mjx.put_model(model)
@jax.jit
def step(data, action):
data = data.replace(ctrl=action)
return mjx.step(mjx_model, data)
# 1000 개 humanoid simulation 동시
batched_step = jax.vmap(step)
batch_data = jax.vmap(mjx.make_data)(...)
JAX-MD — molecular dynamics
pip install jax-md
from jax_md import space, energy, simulate
# 입자 시뮬레이션
displacement_fn, shift_fn = space.periodic(box_size)
energy_fn = energy.lennard_jones_pair(displacement_fn)
init, apply = simulate.nve(energy_fn, shift_fn, dt=1e-3)
state = init(rng, R, mass=1.0)
# rollout
for _ in range(1000):
state = apply(state)
# 미분 가능 — energy 의 parameter 학습 가능
def loss(epsilon, state):
energy_fn = energy.lennard_jones_pair(displacement_fn, epsilon=epsilon)
new_state = apply_with_energy(energy_fn, state)
return some_objective(new_state)
grad_eps = jax.grad(loss)(epsilon, state)
차이 — 왜 JAX 인가?
| 항목 | 전통 (C++/Python) | JAX-based |
|---|---|---|
| 속도 (CPU) | 빠름 (C++) | 비슷 |
| 속도 (GPU) | 드뭄, 별도 코드 | 자동 — 코드 그대로 |
| 대량 vectorize | 수동 (한 sim 한 thread) | vmap 자동 |
| 미분 | 없거나 hand-coded adjoint | 자동 |
| RL 학습 속도 | baseline | 10-1000x |
RL/robotics 학습 — 환경 시뮬레이션이 병목. JAX-based 는 환경 시뮬레이션이 폭발적으로 빠르고 — agent 학습이 자연스럽게 따라옴.
🌐 산업 / 학술 현황
OpenAI, DeepMind, Google Research 가 — 새 RL 코드 거의 100% JAX-based 시뮬레이터로. 그 영향이 학술 / 스타트업으로 확산 중. 5 년 후엔 — robotics + RL = JAX 가 default 가 될 가능성 높음. 지금 — early adopter 시기.
한 가지 — 전통 시뮬레이터의 정확성이 더 높은 영역도 있어 (복잡한 contact, soft body 등). JAX 시뮬레이터들은 — 빠르게 따라잡는 중. trade-off — 큰 batch + 빠른 학습 vs 정확성. RL 의 거의 모든 use case 에서 — 큰 batch 가 더 중요.