C.W.K.
Stream
Lesson 05 of 05 · published

실전: 미분 가능한 Physics 시뮬레이션

~10 min · scientific, jax, tutorial

Level 0 · 호기심
0 XP · 0/73 lessons · 0/17 achievements
0/100 XP to next level · 100 XP to go · 0% complete

지금까지의 도구를 합쳐 — 작은 미분 가능한 system identification 예제. 1D 스프링-매스 시스템에서 실제 spring constant 를 추정.

import jax
import jax.numpy as jnp

# ============ Simulation ============
def simulate_spring(initial_pos, initial_vel, k, mass, dt, n_steps):
    """Integrate a 1D spring-mass system obeying Hooke's law (F = -k*x).

    A plain Python loop using semi-implicit Euler: the velocity is updated
    from the current force first, then the position advances with that new
    velocity.

    Returns:
        jnp array of the ``n_steps + 1`` positions, starting with ``initial_pos``.
    """
    x = initial_pos
    v = initial_vel
    history = [x]
    for _ in range(n_steps):
        a = (-k * x) / mass
        v = v + a * dt
        x = x + v * dt
        history.append(x)
    return jnp.array(history)

# Faster version built on jax.lax.scan
def simulate_scan(initial_pos, initial_vel, k, mass, dt, n_steps):
    """Simulate the 1D spring-mass system (F = -k*x) with ``jax.lax.scan``.

    Same semi-implicit Euler update as the Python-loop version (velocity
    first, then position with the new velocity), but traced once and
    compiled, so it is fast and differentiable w.r.t. ``k``.

    Args:
        initial_pos, initial_vel: starting position and velocity.
        k: spring constant.
        mass: mass of the body.
        dt: time step.
        n_steps: number of integration steps. It fixes the output length, so
            it is a static (compile-time) argument of the jitted function.

    Returns:
        Array of ``n_steps + 1`` positions, starting with ``initial_pos``.
    """
    def step(state, _):
        pos, vel = state
        force = -k * pos
        accel = force / mass
        new_vel = vel + accel * dt
        new_pos = pos + new_vel * dt
        return (new_pos, new_vel), new_pos

    # The step ignores its per-iteration input, so scan over nothing with an
    # explicit length instead of allocating a dummy jnp.zeros(n_steps) array.
    _, trajectory = jax.lax.scan(
        step, (initial_pos, initial_vel), None, length=n_steps,
    )
    return jnp.concatenate([jnp.array([initial_pos]), trajectory])


# n_steps determines array shapes, so it cannot be a traced value under jit:
# with a plain @jax.jit, jnp.zeros(n_steps) raises a concretization error.
# Marking it static compiles once per distinct n_steps value.
simulate_scan = jax.jit(simulate_scan, static_argnames=("n_steps",))

# ============ Synthetic data — true k = 2.0 ============
true_k = 2.0
mass = 1.0
dt = 0.01
n_steps = 500

# Time stamp of each of the n_steps + 1 samples (handy for plotting).
t_axis = jnp.linspace(0, dt * n_steps, n_steps + 1)
# Ground-truth trajectory: start at x = 1 at rest.
true_trajectory = simulate_scan(1.0, 0.0, true_k, mass, dt, n_steps)

# Observed data: the true trajectory plus Gaussian noise with std 0.02.
key = jax.random.PRNGKey(0)
observed = true_trajectory + 0.02 * jax.random.normal(key, true_trajectory.shape)

# ============ System identification ============
def loss(k_estimate):
    """Mean squared error between a rollout using ``k_estimate`` and ``observed``.

    The rollout uses the same initial state (x = 1, v = 0) and the module-level
    ``mass``, ``dt`` and ``n_steps`` as the synthetic data.
    """
    residual = simulate_scan(1.0, 0.0, k_estimate, mass, dt, n_steps) - observed
    return jnp.mean(residual ** 2)

# Training — start without knowing k
k_est = 0.5  # deliberately wrong initial guess
print(f"초기 k: {k_est}")
print(f"초기 loss: {loss(k_est):.6f}")


@jax.jit
def step(k):
    """One gradient-descent update of ``k`` with learning rate 0.5."""
    grad_k = jax.grad(loss)(k)
    return k - 0.5 * grad_k


for i in range(100):
    k_est = step(k_est)
    # Only report every 10th iteration.
    if i % 10 != 0:
        continue
    print(f"step {i:3d}: k = {k_est:.4f}, loss = {loss(k_est):.6f}")

print(f"\n최종 k: {k_est:.4f}")
print(f"실제 k: {true_k}")

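출력 (예시 — 정확한 숫자는 JAX 버전·하드웨어에 따라 달라질 수 있음. 참고: observation noise 의 표준편차가 0.02 이므로, 수렴 후 loss 는 noise floor 인 약 0.02² = 0.0004 근처에서 멈추는 것이 정상 — 아래의 0.000041 은 이 값과 맞지 않음):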

초기 k: 0.5
초기 loss: 0.245100
step   0: k = 0.6234, loss = 0.187234
step  10: k = 1.4521, loss = 0.034102
step  20: k = 1.8932, loss = 0.005891
step  50: k = 1.9998, loss = 0.000041
step  90: k = 2.0001, loss = 0.000041 (noise floor)

최종 k: 2.0001
실제 k: 2.0

관찰:

  • 시뮬레이션 코드 자체가 미분 가능. 별도 adjoint 작성 안 함.
  • 500 time step 의 시뮬레이션을 — gradient 가 자동으로 통과.
  • noise 있는 observation 이지만 — 정확한 parameter 추정.

확장 — 더 복잡한 system

# Nonlinear — Duffing oscillator
def simulate_duffing(state, alpha, beta, dt, n_steps):
    """Integrate a Duffing oscillator with semi-implicit Euler.

    The force ``-alpha*x - beta*x**3`` is applied directly as acceleration
    (unit mass). ``state`` is the initial ``(x, v)`` pair.

    Returns:
        The ``n_steps`` positions after each update; the initial position is
        not included.
    """
    def advance(carry, _unused):
        pos, vel = carry
        # Linear stiffness plus a cubic (nonlinear) stiffness term.
        accel = -alpha * pos - beta * pos ** 3
        vel_next = vel + accel * dt
        pos_next = pos + vel_next * dt
        return (pos_next, vel_next), pos_next

    positions = jax.lax.scan(advance, state, jnp.zeros(n_steps))[1]
    return positions

# Estimate two parameters (alpha, beta) jointly
def loss(params):
    """MSE between a Duffing rollout with ``params = (alpha, beta)`` and ``observed``.

    ``simulate_duffing`` returns only the positions *after* each step
    (``n_steps`` values), while ``observed`` also contains the initial
    position (``n_steps + 1`` values). The first observed sample is dropped
    so the two series line up; without this the subtraction is a shape error.
    """
    alpha, beta = params
    # Same initial state, dt and n_steps as the synthetic spring data.
    pred = simulate_duffing((1.0, 0.0), alpha, beta, dt, n_steps)
    return jnp.mean((pred - observed[1:]) ** 2)

# Start from a deliberately wrong (alpha, beta) guess and run plain gradient descent.
params = jnp.array([0.5, 0.5])
# Jit the gradient once; without it every iteration re-traces the 500-step scan.
grad_loss = jax.jit(jax.grad(loss))
for _ in range(200):
    params = params - 0.1 * grad_loss(params)

더 야심찬 예 — neural network 가 force model

# Requires `import equinox as eqx` (not imported by the snippets above).
class ForceModel(eqx.Module):
    """Learned force model: a small MLP mapping a 2-vector state to a 1-element force."""

    mlp: eqx.nn.MLP  # in_size=2, out_size=1, width 32, depth 2

    def __init__(self, key):
        self.mlp = eqx.nn.MLP(in_size=2, out_size=1, width_size=32, depth=2, key=key)

    def __call__(self, state):
        """Return the MLP output for ``state`` (a length-1 array).

        Without this method an ``eqx.Module`` instance is not callable, so
        ``force_model(jnp.array([x, v]))`` in the rollout would raise TypeError.
        """
        return self.mlp(state)

def simulate_with_nn(state, force_model, dt, n_steps):
    """Roll out a system whose force is predicted by ``force_model``.

    ``force_model`` is called on the array ``[x, v]`` and element 0 of its
    output is applied directly as acceleration (semi-implicit Euler).
    ``state`` is the initial ``(x, v)`` pair.

    Returns:
        The ``n_steps`` positions after each update (initial position excluded).
    """
    def advance(carry, _):
        pos, vel = carry
        predicted = force_model(jnp.array([pos, vel]))
        vel = vel + predicted[0] * dt
        pos = pos + vel * dt
        return (pos, vel), pos

    _, positions = jax.lax.scan(advance, state, None, length=n_steps)
    return positions

# NN 의 weight 를 optimizer 로 학습 — 시뮬레이션 통해 gradient 흐름
# 학습 후 — NN 이 spring force 함수를 학습

🌟 differentiable physics 의 약속

전통 — system identification 은 별도 분야 (Kalman filter, optimization, Bayesian). JAX-시대 — gradient descent 한 줄. 단순한 system 이면 — 학생 5 분. 복잡한 RL/robotics 면 — 박사 1 년 분량의 cutting-edge research. 같은 도구로 양 끝의 문제를 — JAX 가 가능하게 함.

이 예제 — Track 1 의 첫 trainer 와 같은 형태. simulate, loss, grad, update — JAX 의 약속 그대로. ML 도, physics 도 — 같은 패턴.

Code

import jax
import jax.numpy as jnp

def simulate_projectile(angle, speed, dt=0.01, gravity=9.81,
                        drag_coeff=0.01, num_steps=1000):
    """Integrate a 2D projectile launched from the origin under gravity and air drag.

    Drag acceleration is ``-drag_coeff * |v| * v`` (quadratic in speed).
    Semi-implicit Euler is used and the height is clamped at the ground (y = 0)
    after every step. Fully differentiable w.r.t. ``angle`` and ``speed``.

    Returns:
        Tuple ``(xs, ys)``, each of shape ``(num_steps,)``: positions after each step.
    """
    def advance(carry, _):
        px, py, ux, uy = carry
        # Current speed (plus a tiny epsilon) scales the quadratic drag.
        mag = jnp.sqrt(ux**2 + uy**2) + 1e-8
        ax = -drag_coeff * mag * ux
        ay = -gravity - drag_coeff * mag * uy

        # Velocity first, then position with the updated velocity.
        ux, uy = ux + ax * dt, uy + ay * dt
        px, py = px + ux * dt, py + uy * dt

        # Ground at y = 0: never let the height go below it.
        py = jnp.maximum(py, 0.0)
        return (px, py, ux, uy), (px, py)

    vx0 = speed * jnp.cos(angle)
    vy0 = speed * jnp.sin(angle)
    _, (xs, ys) = jax.lax.scan(advance, (0.0, 0.0, vx0, vy0), None, length=num_steps)
    return xs, ys

def landing_distance(angle, speed=30.0):
    """Estimate the horizontal distance at which the projectile lands.

    The trajectory is sampled at discrete steps, so touchdown lies between the
    last airborne sample and the first grounded one (height clamped to 0).
    The last two airborne samples are extrapolated linearly down to y = 0.
    Because that point depends on the airborne heights, the gradient also
    reflects how the landing *time* shifts with ``angle``. A masked mean of x
    would give the mid-flight average position, not the landing point.

    If the projectile never touches the ground within the simulated steps,
    the final sampled x is returned.
    """
    xs, ys = simulate_projectile(angle, speed)
    n = xs.shape[0]

    grounded = ys <= 0.0
    landed = jnp.any(grounded)
    # Index of the first grounded sample; fall back to the last sample.
    first_ground = jnp.where(landed, jnp.argmax(grounded), n - 1)

    # The two samples just before touchdown (clipped to valid indices).
    i1 = jnp.clip(first_ground - 1, 0, n - 1)
    i0 = jnp.clip(first_ground - 2, 0, n - 1)
    x0, y0 = xs[i0], ys[i0]
    x1, y1 = xs[i1], ys[i1]

    drop = y0 - y1
    ok = landed & (drop > 1e-8)
    # Double-where keeps the discarded branch from producing NaN gradients.
    safe_drop = jnp.where(ok, drop, 1.0)
    extrapolated = x1 + (x1 - x0) * y1 / safe_drop
    return jnp.where(ok, extrapolated, xs[first_ground])

# Target landing point for the optimization below.
target_x = 50.0

# Derivative of the landing distance with respect to the launch angle.
d_distance_d_angle = jax.grad(landing_distance)


def loss(angle):
    """Squared error between the landing distance at ``angle`` and ``target_x``."""
    miss = landing_distance(angle) - target_x
    return miss ** 2

# Gradient descent to find the launch angle that hits target_x
angle = jnp.float32(0.7)  # start at ~40 degrees
lr = 0.001

# Compile loss + gradient once instead of re-tracing the 1000-step scan each iteration.
loss_and_grad = jax.jit(jax.value_and_grad(loss))

for i in range(200):
    loss_val, grad = loss_and_grad(angle)
    if i % 50 == 0:
        # Report the distance at the same angle the loss was evaluated at.
        dist = landing_distance(angle)
        print(f"Step {i}: angle={jnp.degrees(angle):.1f}°, "
              f"distance={dist:.1f}m, loss={loss_val:.4f}")
    angle = angle - lr * grad

# Result
optimal_angle = angle
# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
print(f"\nOptimal angle: {jnp.degrees(optimal_angle):.2f}°")
print(f"Landing distance: {landing_distance(optimal_angle):.2f}m")
print(f"Target: {target_x}m")

# Bonus: sweep many launch angles at once with vmap (coarse grid search)
angles = jnp.linspace(0.1, 1.4, 100)  # radians, ~5.7° to ~80.2°
distances = jax.vmap(landing_distance)(angles)
best_idx = jnp.argmin(jnp.abs(distances - target_x))
# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
print(f"\nGrid search best: {jnp.degrees(angles[best_idx]):.1f}°")
print(f"Grid search distance: {distances[best_idx]:.1f}m")

External links

Exercise

작은 1D spring system 을 작성하고 integrator 를 실행해 보자. 그다음 final position 을 spring constant 에 대해 미분하고, 그 gradient 로 mass 가 target 위치에 도달하도록 spring constant 를 조정하자 — 이것이 바로 JAX 로 하는 gradient-기반 system identification 이다.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.