C.W.K.
Stream
Lesson 05 of 06 · published

TFRecord와 Interleave — 대규모 pipeline

~12 min · tfrecord, interleave, large-scale

Level 0Level 0
0 XP0/78 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

ImageNet 규모 dataset의 실제 모습

메모리에 안 들어가고 여러 파일에 나뉜 dataset에는 TensorFlow native binary 포맷 TFRecord 써. 각 파일이 직렬화된 tf.train.Example protocol buffer 담고 있어 — 순차 읽기 빠르고 병렬 디코딩.

interleave는 여러 TFRecord 파일 동시 읽기로 element 섞어. cycle_length=AUTOTUNE, num_parallel_calls=AUTOTUNE이랑 같이 쓰면 느린 storage (HDD, 네트워크 드라이브)에서 단일 파일 읽기가 병목인 상황의 I/O 포화시켜.

TFRecord 써야 할 때: dataset > ~5GB, 파일 수 > 100, 네트워크 파일시스템. CIFAR-10 / MNIST 크기엔 과해.

Code

Writing TFRecord·python
import tensorflow as tf

def serialize_example(image, label):
    feature = {
        'image': tf.train.Feature(
            bytes_list=tf.train.BytesList(
                value=[tf.io.encode_jpeg(image).numpy()])),
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

with tf.io.TFRecordWriter("train_000.tfrecord") as writer:
    for image, label in zip(images, labels):
        writer.write(serialize_example(image, label))
Reading with interleave — parallel I/O·python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

def parse_tfrecord(serialized):
    desc = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    ex = tf.io.parse_single_example(serialized, desc)
    image = tf.io.decode_jpeg(ex['image'], channels=3)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, [224, 224])
    return image, ex['label']

filenames = tf.data.Dataset.list_files("train_*.tfrecord")

ds = (
    filenames
    .interleave(
        lambda fn: tf.data.TFRecordDataset(fn),
        cycle_length=AUTOTUNE,
        num_parallel_calls=AUTOTUNE,
        deterministic=False,    # slight perf gain if order doesn't matter
    )
    .map(parse_tfrecord, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(buffer_size=10000)
    .batch(128, drop_remainder=True)
    .prefetch(AUTOTUNE)
)

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.