C.W.K.
Stream
Lesson 04 of 05 · published

Physics Simulation 과 JAX 생태계

~10 min · scientific, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 가 가장 빛나는 또 한 영역 — 미분 가능한 물리 시뮬레이션. 실제 물리 법칙으로 도는 시뮬레이터를 통해 gradient 가 흘러서 — system identification, control, 강화학습 의 새 패러다임이 가능.

Brax — Google 의 JAX 물리 엔진

pip install brax
from brax import envs
from brax.io import html

# pendulum 환경
env = envs.create("pendulum")

# rollout
state = env.reset(jax.random.PRNGKey(0))
for _ in range(200):
    action = jnp.array([0.0])  # zero action
    state = env.step(state, action)

Brax 는 — JAX 위에서 도는 rigid-body physics. MuJoCo 같은 환경 다수 제공. GPU 위에서 — 1000 개 환경을 vmap 으로 동시 시뮬레이션 가능. 강화학습 학습 속도 1000x 까지.

# 1000 개 environment, 동시 rollout
envs_batch = jax.vmap(env.step, in_axes=(0, 0))
states = jax.vmap(env.reset)(jax.random.split(key, 1000))

for _ in range(200):
    actions = policy(states)   # (1000, action_dim)
    states = envs_batch(states, actions)

MJX — MuJoCo on JAX

2024 년에 DeepMind 가 출시. MuJoCo 의 JAX 포팅. Brax 보다 더 정확한 contact, friction. RL/robotics 표준 환경 다수.

import mujoco
from mujoco import mjx

model = mujoco.MjModel.from_xml_path("humanoid.xml")
mjx_model = mjx.put_model(model)

@jax.jit
def step(data, action):
    data = data.replace(ctrl=action)
    return mjx.step(mjx_model, data)

# 1000 개 humanoid simulation 동시
batched_step = jax.vmap(step)
batch_data = jax.vmap(mjx.make_data)(...)

JAX-MD — molecular dynamics

pip install jax-md
from jax_md import space, energy, simulate

# 입자 시뮬레이션
displacement_fn, shift_fn = space.periodic(box_size)
energy_fn = energy.lennard_jones_pair(displacement_fn)

init, apply = simulate.nve(energy_fn, shift_fn, dt=1e-3)
state = init(rng, R, mass=1.0)

# rollout
for _ in range(1000):
    state = apply(state)

# 미분 가능 — energy 의 parameter 학습 가능
def loss(epsilon, state):
    energy_fn = energy.lennard_jones_pair(displacement_fn, epsilon=epsilon)
    new_state = apply_with_energy(energy_fn, state)
    return some_objective(new_state)

grad_eps = jax.grad(loss)(epsilon, state)

차이 — 왜 JAX 인가?

항목전통 (C++/Python)JAX-based
속도 (CPU)빠름 (C++)비슷
속도 (GPU)드뭄, 별도 코드자동 — 코드 그대로
대량 vectorize수동 (한 sim 한 thread)vmap 자동
미분없거나 hand-coded adjoint자동
RL 학습 속도baseline10-1000x

RL/robotics 학습 — 환경 시뮬레이션이 병목. JAX-based 는 환경 시뮬레이션이 폭발적으로 빠르고 — agent 학습이 자연스럽게 따라옴.

🌐 산업 / 학술 현황

OpenAI, DeepMind, Google Research 가 — 새 RL 코드 거의 100% JAX-based 시뮬레이터로. 그 영향이 학술 / 스타트업으로 확산 중. 5 년 후엔 — robotics + RL = JAX 가 default 가 될 가능성 높음. 지금 — early adopter 시기.

한 가지 — 전통 시뮬레이터의 정확성이 더 높은 영역도 있어 (복잡한 contact, soft body 등). JAX 시뮬레이터들은 — 빠르게 따라잡는 중. trade-off — 큰 batch + 빠른 학습 vs 정확성. RL 의 거의 모든 use case 에서 — 큰 batch 가 더 중요.

Code

# JAX-MD provides differentiable molecular dynamics simulations
# Example concept (simplified):
import jax
import jax.numpy as jnp

def lennard_jones(r, epsilon=1.0, sigma=1.0):
    """Lennard-Jones potential between two particles at distance r."""
    s6 = (sigma / r) ** 6
    return 4.0 * epsilon * (s6**2 - s6)

# The potential is differentiable — get forces automatically
force = -jax.grad(lennard_jones)
print(f"Force at r=1.5: {force(1.5):.4f}")

# Vectorize over all pairs
def total_energy(positions):
    """Compute total LJ energy for a system of particles."""
    n = positions.shape[0]
    energy = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            r = jnp.linalg.norm(positions[i] - positions[j])
            energy += lennard_jones(r)
    return energy

# Get forces on all particles: F = -∇E
forces = -jax.grad(total_energy)(positions)
# Brax enables training robot controllers through physics
# The entire simulation is differentiable:
# environment_step → reward → grad(policy_params)
# This is vastly faster than finite-difference methods used in
# traditional RL for continuous control.
# JAXopt provides scipy-style optimization with JAX
import jaxopt

# Minimize a function with L-BFGS
def rosenbrock(x):
    return (1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2

solver = jaxopt.LBFGS(fun=rosenbrock, maxiter=100)
result = solver.run(jnp.array([0.0, 0.0]))
print(f"Minimum at: {result.params}")  # ≈ [1.0, 1.0]

# Projected gradient descent (constrained optimization)
solver = jaxopt.ProjectedGradient(
    fun=objective,
    projection=jaxopt.projection.projection_box,
    maxiter=200,
)
import jax
import jax.numpy as jnp

# Compute Taylor coefficients of a function
# This is useful for solving PDEs with high-order derivatives
def f(x):
    return jnp.sin(x) * jnp.exp(-x)

# Get the 4th derivative efficiently
from jax import jet

# jet computes Taylor-mode derivatives
primals = jnp.array(1.0)
series = (jnp.array(1.0), jnp.array(0.0), jnp.array(0.0), jnp.array(0.0))
primal_out, series_out = jax.jet(f, (primals,), (series,))
# series_out contains higher-order derivatives
# This scales linearly with derivative order, not exponentially

External links

Exercise

Brax pendulum 환경 spin up. 1 rollout. 그 다음 trajectory 의 final cost 의 initial-state parameter 에 대한 미분. End-to-end differentiable physics — 한 번 보면 — 응용을 못 잊음.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.