Official source code: https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py
Add a device setting (run on the CPU only) and disable eager execution:
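```python
import os

# "-1" hides all GPUs, so TensorFlow runs on the CPU only;
# set this before TensorFlow initializes any devices
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# The script uses TF1-style APIs, so turn off eager execution
tf.disable_eager_execution()
```

Change the storage paths: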
```python
flags.DEFINE_string(
    'local_scratch_dir', '/sas/imagenet-scratch',
    'Scratch directory path for temporary files.')
flags.DEFINE_string(
    'raw_data_dir', '/sas/ImageNet',
    'Directory path for raw Imagenet dataset. '
    'Should have train and validation subdirectories inside it.')
```

If you are not uploading to GCS, comment out the following import:
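```python
from google.cloud import storage
```

Then simply comment out the corresponding places marked in red (the code that uses `storage`). This does not affect the conversion itself.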
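As an alternative to commenting code out, the import can be made optional so the same file works with or without the GCS client. This is a minimal sketch, not part of the official script: `upload_to_gcs` here is a hypothetical wrapper, and it only uses the standard `google-cloud-storage` calls `Client()`, `bucket()`, `blob()` and `upload_from_filename()`.

```python
import os

try:
    # Only needed for uploading; a purely local conversion works without it
    from google.cloud import storage
except ImportError:
    storage = None


def upload_to_gcs(local_paths, bucket_name, prefix):
    """Hypothetical helper: upload files to GCS when the client is available.

    Returns the uploaded blob names, or an empty list when
    google-cloud-storage is missing or no bucket name is given.
    """
    if storage is None or not bucket_name:
        print("GCS client unavailable or no bucket given; skipping upload")
        return []
    bucket = storage.Client().bucket(bucket_name)
    uploaded = []
    for path in local_paths:
        name = os.path.basename(path)
        blob_name = f"{prefix.rstrip('/')}/{name}" if prefix else name
        bucket.blob(blob_name).upload_from_filename(path)
        uploaded.append(blob_name)
    return uploaded
```

With an empty bucket name the function returns immediately, so the TFRecords simply stay in the local scratch directory.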
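Change the variable `LABELS_FILE`. The file it names should hold the validation-set labels, and those labels are not numbers but the same synset strings used as the training folder names, e.g. `n01751748`.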
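Before running the full conversion, it can help to check that the labels file really contains synset IDs rather than numeric class indices. A minimal sketch, assuming one label per line in the same order as the sorted validation filenames; `load_validation_synsets` is a hypothetical helper, not a function from the script.

```python
import re

# ImageNet synset IDs look like "n01751748": the letter "n" followed by 8 digits
SYNSET_RE = re.compile(r"^n\d{8}$")


def load_validation_synsets(labels_path, train_synsets=None):
    """Read a validation labels file and check every line is a synset ID.

    If train_synsets is given, each label must also match a training
    folder name.
    """
    with open(labels_path, "r") as f:
        lines = f.read().splitlines()
    # Drop trailing blank lines only; a blank line in the middle would
    # shift every later label, so it is reported as an error below
    while lines and not lines[-1].strip():
        lines.pop()
    synsets = []
    for lineno, line in enumerate(lines, start=1):
        s = line.strip()
        if not SYNSET_RE.match(s):
            raise ValueError(f"line {lineno}: {s!r} is not a synset ID like n01751748")
        if train_synsets is not None and s not in train_synsets:
            raise ValueError(f"line {lineno}: {s!r} has no matching training folder")
        synsets.append(s)
    return synsets
```

Assuming the labels file sits under `raw_data_dir`, a call such as `load_validation_synsets(os.path.join(raw_data_dir, LABELS_FILE), set(os.listdir(os.path.join(raw_data_dir, 'train'))))` would fail fast on a file of numeric labels.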
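In Python 3, `range()` returns a lazy range object instead of a list, so `random.shuffle` cannot shuffle it in place. Change `make_shuffle_idx` from:

```python
def make_shuffle_idx(n):
    order = range(n)
    random.shuffle(order)
    return order
```

to: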
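```python
def make_shuffle_idx(n):
    order = range(n)
    order = list(order)  # convert to a list so it can be shuffled in place
    random.shuffle(order)
    return order
```

JPEG files are binary data, so in Python 3 they must be opened in binary mode. Change:

```python
with tf.gfile.FastGFile(filename, 'r') as f:
    image_data = f.read()
```

to: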
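```python
with tf.gfile.FastGFile(filename, 'rb') as f:
    image_data = f.read()
```

In Python 3, byte-string features (`tf.train.BytesList`) only accept `bytes`, not `str`, so change the string constants:

```python
colorspace = 'RGB'
channels = 3
image_format = 'JPEG'
```

to: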
```python
colorspace = b'RGB'
channels = 3
image_format = b'JPEG'
```

For the same reason, `synset` and `filename` must also be passed as `bytes` when building the `tf.train.Example`. Change:

```python
example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': _int64_feature(height),
    'image/width': _int64_feature(width),
    'image/colorspace': _bytes_feature(colorspace),
    'image/channels': _int64_feature(channels),
    'image/class/label': _int64_feature(label),
    'image/class/synset': _bytes_feature(synset),
    'image/format': _bytes_feature(image_format),
    'image/filename': _bytes_feature(os.path.basename(filename)),
    'image/encoded': _bytes_feature(image_buffer)}))
```

to:
```python
example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': _int64_feature(height),
    'image/width': _int64_feature(width),
    'image/colorspace': _bytes_feature(colorspace),
    'image/channels': _int64_feature(channels),
    'image/class/label': _int64_feature(label),
    'image/class/synset': _bytes_feature(str.encode(synset)),
    'image/format': _bytes_feature(image_format),
    'image/filename': _bytes_feature(os.path.basename(str.encode(filename))),
    'image/encoded': _bytes_feature(image_buffer)}))
```