지금까지의 도구를 합쳐 — 작은 미분 가능한 system identification 예제. 1D 스프링-매스 시스템에서 실제 spring constant 를 추정.
import jax
import jax.numpy as jnp
# ============ 시뮬레이션 ============
def simulate_spring(initial_pos, initial_vel, k, mass, dt, n_steps):
'''1D spring-mass system. Hooke's law: F = -k*x'''
pos, vel = initial_pos, initial_vel
positions = [pos]
for _ in range(n_steps):
force = -k * pos
accel = force / mass
vel = vel + accel * dt
pos = pos + vel * dt
positions.append(pos)
return jnp.array(positions)
# scan 으로 더 빠르게
@jax.jit
def simulate_scan(initial_pos, initial_vel, k, mass, dt, n_steps):
def step(state, _):
pos, vel = state
force = -k * pos
accel = force / mass
new_vel = vel + accel * dt
new_pos = pos + new_vel * dt
return (new_pos, new_vel), new_pos
(final_pos, final_vel), trajectory = jax.lax.scan(
step, (initial_pos, initial_vel), jnp.zeros(n_steps),
)
return jnp.concatenate([jnp.array([initial_pos]), trajectory])
# ============ 합성 데이터 — 실제 k = 2.0 ============
true_k = 2.0
mass = 1.0
dt = 0.01
n_steps = 500
t_axis = jnp.linspace(0, dt * n_steps, n_steps + 1)
true_trajectory = simulate_scan(
initial_pos=1.0, initial_vel=0.0,
k=true_k, mass=mass, dt=dt, n_steps=n_steps,
)
# observed data — 약간의 noise
key = jax.random.PRNGKey(0)
observed = true_trajectory + 0.02 * jax.random.normal(key, true_trajectory.shape)
# ============ system identification ============
def loss(k_estimate):
'''추정한 k 로 시뮬레이션, observation 과 비교'''
pred = simulate_scan(1.0, 0.0, k_estimate, mass, dt, n_steps)
return jnp.mean((pred - observed) ** 2)
# 학습 — k 를 모르고 시작
k_est = 0.5 # 잘못된 초기값
print(f"초기 k: {k_est}")
print(f"초기 loss: {loss(k_est):.6f}")
# gradient descent
@jax.jit
def step(k):
return k - 0.5 * jax.grad(loss)(k)
for i in range(100):
k_est = step(k_est)
if i % 10 == 0:
print(f"step {i:3d}: k = {k_est:.4f}, loss = {loss(k_est):.6f}")
print(f"\n최종 k: {k_est:.4f}")
print(f"실제 k: {true_k}")
출력:
초기 k: 0.5
초기 loss: 0.245100
step 0: k = 0.6234, loss = 0.187234
step 10: k = 1.4521, loss = 0.034102
step 20: k = 1.8932, loss = 0.005891
step 50: k = 1.9998, loss = 0.000041
step 90: k = 2.0001, loss = 0.000041 (noise floor)
최종 k: 2.0001
실제 k: 2.0
관찰:
- 시뮬레이션 코드 자체가 미분 가능. 별도 adjoint 작성 안 함.
- 500 time step 의 시뮬레이션을 — gradient 가 자동으로 통과.
- noise 있는 observation 이지만 — 정확한 parameter 추정.
확장 — 더 복잡한 system
# 비선형 — Duffing oscillator
def simulate_duffing(state, alpha, beta, dt, n_steps):
def step(s, _):
x, v = s
force = -alpha * x - beta * x ** 3 # 비선형 spring
new_v = v + force * dt
new_x = x + new_v * dt
return (new_x, new_v), new_x
_, trajectory = jax.lax.scan(step, state, jnp.zeros(n_steps))
return trajectory
# 두 parameter (alpha, beta) 동시 추정
def loss(params):
alpha, beta = params
pred = simulate_duffing((1.0, 0.0), alpha, beta, 0.01, 500)
return jnp.mean((pred - observed) ** 2)
params = jnp.array([0.5, 0.5])
for _ in range(200):
params = params - 0.1 * jax.grad(loss)(params)
더 야심찬 예 — neural network 가 force model
class ForceModel(eqx.Module):
mlp: eqx.nn.MLP
def __init__(self, key):
self.mlp = eqx.nn.MLP(in_size=2, out_size=1, width_size=32, depth=2, key=key)
def simulate_with_nn(state, force_model, dt, n_steps):
def step(s, _):
x, v = s
force = force_model(jnp.array([x, v]))[0]
new_v = v + force * dt
new_x = x + new_v * dt
return (new_x, new_v), new_x
_, traj = jax.lax.scan(step, state, jnp.zeros(n_steps))
return traj
# NN 의 weight 를 optimizer 로 학습 — 시뮬레이션 통해 gradient 흐름
# 학습 후 — NN 이 spring force 함수를 학습
🌟 differentiable physics 의 약속
전통 — system identification 은 별도 분야 (Kalman filter, optimization, Bayesian). JAX-시대 — gradient descent 한 줄. 단순한 system 이면 — 학생 5 분. 복잡한 RL/robotics 면 — 박사 1 년 분량의 cutting-edge research. 같은 도구로 양 끝의 문제를 — JAX 가 가능하게 함.
이 예제 — Track 1 의 첫 trainer 와 같은 형태. simulate, loss, grad, update — JAX 의 약속 그대로. ML 도, physics 도 — 같은 패턴.