C.W.K.
Stream
Lesson 02 of 04 · published

tf.distribute strategies

~13 min · mirrored, tpu-strategy, multi-worker


Multi-GPU or TPU in two lines of code

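The TF distribution API (tf.distribute) runs a single-device script across multiple GPUs or TPU chips with minimal code changes: create a Strategy, build and compile the model inside strategy.scope(), and model.fit handles gradient synchronization automatically.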
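A minimal sketch of that pattern with a deliberately tiny model, so it also runs on a CPU-only machine (there MirroredStrategy falls back to a single replica on the CPU). The layer sizes and the random training data are placeholders, not part of the lesson:

```python
import numpy as np
import tensorflow as tf

# No GPUs visible -> MirroredStrategy uses the CPU as a single replica.
strategy = tf.distribute.MirroredStrategy()

# The "two lines": create the strategy above, then build + compile under its scope
# so the weights and optimizer state are created as mirrored variables.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(8,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

# Placeholder data (assumption): random features and regression targets.
x = np.random.rand(256, 8).astype("float32")
y = np.random.rand(256, 1).astype("float32")

# fit() runs outside the scope; batch_size here is the GLOBAL batch,
# which Keras splits across replicas and all-reduces the gradients.
history = model.fit(
    x, y,
    batch_size=32 * strategy.num_replicas_in_sync,
    epochs=2,
    verbose=0,
)
print("replicas:", strategy.num_replicas_in_sync,
      "epochs run:", len(history.history["loss"]))
```

The same script runs unchanged on a multi-GPU machine; only num_replicas_in_sync (and therefore the global batch) changes.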

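| Strategy | Hardware | Sync | Status |
|---|---|---|---|
| MirroredStrategy | 1 machine, N GPUs | Sync (NCCL all-reduce) | Stable |
| TPUStrategy | TPU devices or pods (v2/v3/v4) | Sync | Stable |
| MultiWorkerMirroredStrategy | N machines × M GPUs each | Sync | Stable |
| ParameterServerStrategy | N workers + parameter servers | Async | Experimental |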
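For the MultiWorkerMirroredStrategy row, every machine runs the same script and learns the cluster layout from the TF_CONFIG environment variable, a JSON document with a "cluster" and a "task" section. A minimal sketch of building it, assuming a workers-only cluster; the helper name make_tf_config and the host addresses are hypothetical:

```python
import json


def make_tf_config(workers, index):
    """Build the TF_CONFIG JSON string for one worker process (hypothetical helper).

    workers: list of "host:port" strings, identical on every machine.
    index:   this machine's position in that list.
    """
    if not 0 <= index < len(workers):
        raise ValueError(f"index {index} out of range for {len(workers)} workers")
    return json.dumps({
        "cluster": {"worker": list(workers)},
        "task": {"type": "worker", "index": index},
    })


# Hypothetical hosts: replace with your own machines and a free port.
WORKERS = ["10.0.0.1:12345", "10.0.0.2:12345"]
cfg = make_tf_config(WORKERS, index=1)
print(cfg)
```

On each machine, set `os.environ["TF_CONFIG"] = make_tf_config(WORKERS, i)` with its own index i before constructing `tf.distribute.MultiWorkerMirroredStrategy()`, since the strategy reads TF_CONFIG when it is created. With no explicit chief task, worker 0 acts as the chief (for example for checkpoint and summary writing).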

Scale the global batch size with the number of replicas: with a per-replica batch of 64 on 4 GPUs, the global batch is 256. Scale the learning rate to match (the linear scaling rule), and add a warmup period for very large batches.
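The arithmetic behind that rule, as a small plain-Python sketch. The base learning rate of 1e-3 tuned at batch 64 and the 500-step linear warmup are illustrative assumptions, and the helper names are hypothetical:

```python
def global_batch_size(per_replica_batch: int, num_replicas: int) -> int:
    """Global batch = per-replica batch x replicas in sync."""
    return per_replica_batch * num_replicas


def scaled_lr(base_lr: float, base_batch: int, global_batch: int) -> float:
    """Linear scaling rule: the LR grows in proportion to the batch size."""
    return base_lr * global_batch / base_batch


def warmup_lr(step: int, peak_lr: float, warmup_steps: int) -> float:
    """Ramp linearly up to peak_lr over warmup_steps, then hold it."""
    if warmup_steps <= 0:
        return peak_lr
    return peak_lr * min(1.0, (step + 1) / warmup_steps)


gb = global_batch_size(64, 4)   # 4 GPUs x 64 per replica -> 256
lr = scaled_lr(1e-3, 64, gb)    # assumption: 1e-3 was tuned at batch 64 -> 4e-3
for step in (0, 249, 499, 1000):
    print(step, warmup_lr(step, lr, warmup_steps=500))
```

Warmup keeps the first updates small while the scaled learning rate would otherwise be too aggressive for freshly initialized weights.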

Code

MirroredStrategy: single machine, multi-GPU · python
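import tensorflow as tf

# With no arguments, MirroredStrategy uses every visible GPU (or the CPU if none).
strategy = tf.distribute.MirroredStrategy()
print(f"Replicas: {strategy.num_replicas_in_sync}")

# Model weights and optimizer state must be created inside the scope
# so they are mirrored on every replica.
with strategy.scope():
    model = tf.keras.applications.ResNet50(weights=None, classes=1000)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

PER_REPLICA = 64
GLOBAL_BATCH = PER_REPLICA * strategy.num_replicas_in_sync

# Placeholder data (assumption): swap in your real tf.data pipeline.
# Batch by GLOBAL_BATCH; model.fit splits each batch across the replicas.
def make_ds(num_batches):
    ds = tf.data.Dataset.range(num_batches * GLOBAL_BATCH).map(
        lambda _: (tf.random.uniform((224, 224, 3)),
                   tf.random.uniform((), maxval=1000, dtype=tf.int32)))
    return ds.batch(GLOBAL_BATCH).prefetch(tf.data.AUTOTUNE)

train_ds, val_ds = make_ds(8), make_ds(2)

model.fit(train_ds, epochs=10, validation_data=val_ds)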
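The MirroredStrategy constructor also takes an explicit device list and a cross_device_ops object. A sketch, assuming you want to pin the replicas or replace the default NCCL all-reduce (for example where NCCL is unavailable); tf.distribute.HierarchicalCopyAllReduce and tf.distribute.ReductionToOneDevice are TF's built-in alternatives, while the fallback logic itself is illustrative:

```python
import tensorflow as tf

# Pin replicas to the visible GPUs; fall back to the CPU when there are none.
gpus = [d.name for d in tf.config.list_logical_devices("GPU")]
devices = gpus if gpus else ["/cpu:0"]

# NCCL is the default all-reduce for multiple GPUs. Swap in a non-NCCL
# implementation for multi-device runs, and reduce on one device otherwise.
cross_ops = (tf.distribute.HierarchicalCopyAllReduce() if len(devices) > 1
             else tf.distribute.ReductionToOneDevice())

strategy = tf.distribute.MirroredStrategy(devices=devices,
                                          cross_device_ops=cross_ops)
print("Devices:", devices, "Replicas:", strategy.num_replicas_in_sync)
```

The rest of the script (scope, compile, fit) stays the same as in the listing above.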
TPUStrategy: Google Cloud TPU or Colab · python
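import tensorflow as tf

# In Colab: empty TPU spec auto-detects
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)   # wipes TPU memory; run before building models

strategy = tf.distribute.TPUStrategy(resolver)
print(f"TPU cores: {strategy.num_replicas_in_sync}")

# Create the model and optimizer inside the scope so variables live on the TPU.
with strategy.scope():
    model = tf.keras.applications.EfficientNetV2S(
        input_shape=(224, 224, 3), classes=10, weights=None,
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

# TPU efficiency: large batches per core (128 per core here)
BATCH_SIZE = 128 * strategy.num_replicas_in_sync

# Feed a batched tf.data.Dataset; drop_remainder=True keeps every batch
# the same static shape for XLA compilation on the TPU.
# Placeholder data (assumption): swap in your real input pipeline.
def make_ds(num_batches):
    ds = tf.data.Dataset.range(num_batches * BATCH_SIZE).map(
        lambda _: (tf.random.uniform((224, 224, 3), maxval=255.0),
                   tf.random.uniform((), maxval=10, dtype=tf.int32)))
    return ds.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

train_dataset, val_dataset = make_ds(8), make_ds(2)
model.fit(train_dataset, validation_data=val_dataset, epochs=20)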
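To see how a global batch is divided among replicas without GPUs or a TPU, one option is to split the host CPU into two logical devices and mirror across them. This is a testing sketch, not a performance setup: the two-device split and the batch sizes are assumptions, and set_logical_device_configuration must run before TensorFlow initializes its devices:

```python
import tensorflow as tf

# Split the first physical CPU into two logical devices (before TF initializes).
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu,
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()],
)
cpus = [d.name for d in tf.config.list_logical_devices("CPU")]
strategy = tf.distribute.MirroredStrategy(devices=cpus)

PER_REPLICA = 4
GLOBAL_BATCH = PER_REPLICA * strategy.num_replicas_in_sync  # 8 with 2 replicas

# Batch by the GLOBAL size; the strategy splits each batch across replicas.
ds = tf.data.Dataset.range(32).batch(GLOBAL_BATCH, drop_remainder=True)
dist_ds = strategy.experimental_distribute_dataset(ds)

per_replica_sizes = []
for batch in dist_ds:
    # One tensor per replica for this global batch.
    parts = strategy.experimental_local_results(batch)
    per_replica_sizes.append([int(tf.shape(p)[0]) for p in parts])
print(per_replica_sizes)
```

Each global batch of 8 arrives as two per-replica batches of 4, the same division that gives each TPU core 128 examples from BATCH_SIZE in the listing above.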
