Converting ImageNet images to TFRecord (bugs caused by the official source code and version differences)

    Tech  2022-07-11

    Official source code: https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py

    Specify the hardware (a value of "-1" hides all GPUs, so the script runs on the CPU) and disable eager execution:

    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide all GPUs; run on the CPU only
    tf.disable_eager_execution()  # the script builds TF1-style graphs and sessions
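The environment variable is read when CUDA is initialized, so it is safest to set it before TensorFlow is imported. A minimal sketch of such a guard, assuming you want a warning when the order is wrong; `hide_gpus` is a hypothetical helper, not part of the official script:

```python
import os
import sys


def hide_gpus():
    """Hide all GPUs from CUDA by setting CUDA_VISIBLE_DEVICES to "-1".

    Hypothetical helper. Returns True if TensorFlow was already imported,
    in which case the setting may come too late to take effect.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    return "tensorflow" in sys.modules


if __name__ == "__main__":
    if hide_gpus():
        print("warning: tensorflow was imported before CUDA_VISIBLE_DEVICES was set")
    # Import TensorFlow only after the variable is set, e.g.:
    # import tensorflow.compat.v1 as tf
    # tf.disable_eager_execution()
```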

    Change the storage paths (the default values of these two flags):

    flags.DEFINE_string(
        'local_scratch_dir', '/sas/imagenet-scratch',
        'Scratch directory path for temporary files.')
    flags.DEFINE_string(
        'raw_data_dir', '/sas/ImageNet',
        'Directory path for raw Imagenet dataset. '
        'Should have train and validation subdirectories inside it.')
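The help text of `raw_data_dir` expects `train` and `validation` subdirectories. A minimal pre-flight check, assuming exactly that layout; `missing_subdirs` is a hypothetical helper and the path in the usage line simply mirrors the default above:

```python
import os


def missing_subdirs(raw_data_dir, required=("train", "validation")):
    """Return the required subdirectories that are absent under raw_data_dir."""
    return [name for name in required
            if not os.path.isdir(os.path.join(raw_data_dir, name))]


if __name__ == "__main__":
    missing = missing_subdirs("/sas/ImageNet")
    if missing:
        print("missing subdirectories:", missing)
```

Running this before the conversion catches a wrong `raw_data_dir` early, instead of partway through the job.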

    If you are not uploading to GCS, you can comment out the following import:

    from google.cloud import storage

    Then comment out the code that references `storage` (the GCS upload step). This does not affect the conversion itself.
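As an alternative to commenting lines out, the import can be made optional. A sketch, assuming the upload code is only reached when uploads are enabled; `upload_enabled` and `can_upload` are hypothetical names, not necessarily flags or functions of the official script:

```python
try:
    # Only needed for the GCS upload step.
    from google.cloud import storage
except ImportError:
    storage = None


def can_upload(upload_enabled):
    """Return True only if uploads are requested and the GCS client is installed."""
    return bool(upload_enabled) and storage is not None


if __name__ == "__main__":
    if not can_upload(upload_enabled=False):
        print("skipping GCS upload")
```

The upload call would then be wrapped in `if can_upload(...):`, so the script still runs on machines without `google-cloud-storage`.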

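    Change the variable LABELS_FILE. The file this string points to should contain the validation-set labels. These labels are not numbers but synset IDs that match the training-set folder names, e.g. n01751748.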
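Before running the conversion, it can help to check that every line of the labels file is a synset ID that also names a training folder. A minimal sketch, assuming one label per line and a `train` directory with one subfolder per synset; `check_validation_labels` is a hypothetical helper:

```python
import os
import re

SYNSET_RE = re.compile(r"^n\d{8}$")  # WordNet ID, e.g. n01751748


def check_validation_labels(labels_path, train_dir):
    """Return (bad_format, unknown) lists of problematic labels.

    bad_format: lines that do not look like a synset ID (e.g. numeric labels).
    unknown:    well-formed IDs with no matching folder in train_dir.
    """
    train_synsets = set(os.listdir(train_dir))
    bad_format, unknown = [], []
    with open(labels_path, "r") as f:
        for line in f:
            label = line.strip()
            if not label:
                continue  # ignore blank lines
            if not SYNSET_RE.match(label):
                bad_format.append(label)
            elif label not in train_synsets:
                unknown.append(label)
    return bad_format, unknown
```

A non-empty `bad_format` list usually means the file holds numeric class indices instead of synset IDs.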

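    Under Python 3, `range()` returns an immutable range object, so `random.shuffle` raises a `TypeError` on the original code:

    def make_shuffle_idx(n):
        order = range(n)
        random.shuffle(order)
        return order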
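The failure of the original `make_shuffle_idx` under Python 3 can be reproduced with the standard library alone. This sketch shows the `TypeError` from shuffling a `range` directly and the list-based version that avoids it:

```python
import random


def make_shuffle_idx(n):
    """Return a random permutation of 0..n-1 (Python 3 compatible)."""
    order = list(range(n))  # range objects are immutable in Python 3
    random.shuffle(order)   # shuffles the list in place
    return order


if __name__ == "__main__":
    try:
        random.shuffle(range(5))
    except TypeError as err:
        print("TypeError:", err)  # 'range' object does not support item assignment
    print(make_shuffle_idx(5))
```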

    Change `make_shuffle_idx` to:

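    def make_shuffle_idx(n):
        order = range(n)
        order = list(order)  # convert to a mutable list so it can be shuffled
        random.shuffle(order)
        return order

    Next, the image file is opened in text mode. JPEG data is binary, so under Python 3 reading it as text typically fails with a `UnicodeDecodeError`:

    with tf.gfile.FastGFile(filename, 'r') as f:
        image_data = f.read()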

    Change the mode to 'rb':

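    with tf.gfile.FastGFile(filename, 'rb') as f:
        image_data = f.read()

    When the Example is built, the following constants are `str`, but `_bytes_feature` wraps its argument in a `tf.train.BytesList`, which only accepts `bytes` in Python 3:

    colorspace = 'RGB'
    channels = 3
    image_format = 'JPEG'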

    Change them to bytes literals:

    colorspace = b'RGB'
    channels = 3
    image_format = b'JPEG'

    The same problem affects `synset` and `filename`, which are still `str` when they are written into the Example:

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(colorspace),
        'image/channels': _int64_feature(channels),
        'image/class/label': _int64_feature(label),
        'image/class/synset': _bytes_feature(synset),
        'image/format': _bytes_feature(image_format),
        'image/filename': _bytes_feature(os.path.basename(filename)),
        'image/encoded': _bytes_feature(image_buffer)}))

    Change it to encode both values to bytes first:

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(colorspace),
        'image/channels': _int64_feature(channels),
        'image/class/label': _int64_feature(label),
        'image/class/synset': _bytes_feature(str.encode(synset)),
        'image/format': _bytes_feature(image_format),
        'image/filename': _bytes_feature(os.path.basename(str.encode(filename))),
        'image/encoded': _bytes_feature(image_buffer)}))
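Under Python 3, every value passed to `_bytes_feature` must already be `bytes`. One way to avoid scattering `str.encode` calls is to normalize the string fields in one place. A sketch, assuming UTF-8 encoding; `as_bytes` and `encode_string_fields` are hypothetical helpers (TensorFlow's own `tf.compat.as_bytes` does the same conversion), and each value would then be passed to `_bytes_feature` as in the code above:

```python
import os


def as_bytes(value, encoding="utf-8"):
    """Return value as bytes, encoding str with the given codec."""
    if isinstance(value, bytes):
        return value
    if isinstance(value, str):
        return value.encode(encoding)
    raise TypeError("expected str or bytes, got %s" % type(value).__name__)


def encode_string_fields(filename, synset, colorspace="RGB", image_format="JPEG"):
    """Collect the string-valued Example fields as bytes (hypothetical helper)."""
    return {
        "image/colorspace": as_bytes(colorspace),
        "image/class/synset": as_bytes(synset),
        "image/format": as_bytes(image_format),
        "image/filename": as_bytes(os.path.basename(filename)),
    }
```

Because `as_bytes` accepts both types, the same code keeps working whether `synset` arrives as `str` or as `bytes` (for example, if the labels file is read in `'rb'` mode).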