C.W.K.
Stream
Lesson 02 of 06 · published

Dtype: float32, bfloat16, 그리고 정확도의 함정

~10 min · numpy, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 의 default dtype 은 float32. NumPy 는 float64. 이 차이를 모르면 — 같은 코드인데 답이 다르게 나오는 일 생겨.

import numpy as np
import jax.numpy as jnp

a_np = np.array([1.0, 2.0, 3.0])
a_jax = jnp.array([1.0, 2.0, 3.0])

print(a_np.dtype)   # float64
print(a_jax.dtype)  # float32  (!)

왜 float32 가 default 냐 — 가속기 (GPU/TPU) 에서 float32 가 보통 float64 의 2~30 배 빨라. ML workload 는 거의 다 float32 로 충분히 수렴해.

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
a = jnp.array([1.0, 2.0])
print(a.dtype)  # float64 (이제 됨)

⚠️ 정밀도 함정

NumPy 에서 float64 로 동작하던 코드를 JAX 로 옮기면 — 같은 답이 안 나올 수 있어. 학습이 안 되거나 NaN 나면 가장 먼저 dtype 체크.

Code

import jax.numpy as jnp

# Standard dtypes
f32 = jnp.array(1.0, dtype=jnp.float32)   # 32-bit float (default)
f64 = jnp.array(1.0, dtype=jnp.float64)   # 64-bit float
i32 = jnp.array(1, dtype=jnp.int32)       # 32-bit integer
b = jnp.array(True, dtype=jnp.bool_)      # Boolean

# ML-specific dtypes
f16 = jnp.array(1.0, dtype=jnp.float16)   # 16-bit float (half precision)
bf16 = jnp.array(1.0, dtype=jnp.bfloat16) # Brain floating point 16
import jax.numpy as jnp

# float16 has limited range
try:
    big_f16 = jnp.array(100000.0, dtype=jnp.float16)
    print(f"float16: {big_f16}")  # inf — overflows!
except:
    pass

# bfloat16 handles the same value fine
big_bf16 = jnp.array(100000.0, dtype=jnp.bfloat16)
print(f"bfloat16: {big_bf16}")  # 99840.0 — less precise but doesn't overflow

# float32 for reference
big_f32 = jnp.array(100000.0, dtype=jnp.float32)
print(f"float32: {big_f32}")    # 100000.0

# Memory usage: half of float32
arr_f32 = jnp.ones((1000, 1000), dtype=jnp.float32)
arr_bf16 = jnp.ones((1000, 1000), dtype=jnp.bfloat16)
print(f"float32 size: {arr_f32.nbytes / 1e6:.1f} MB")  # 4.0 MB
print(f"bfloat16 size: {arr_bf16.nbytes / 1e6:.1f} MB") # 2.0 MB
import jax
jax.config.update("jax_enable_x64", True)

# Now float64 is available
import jax.numpy as jnp
x = jnp.array(1.0, dtype=jnp.float64)
print(x.dtype)  # float64

External links

Exercise

같은 dot product 를 float32, float64 (x64 enable), bfloat16 으로 계산. 결과, max abs error, time 출력. bfloat16 이 답이 무너지는 input 한 개 찾기. 실제 학습 코드에서 어떻게 할지 적기.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.