基于 BP（全连接）网络的生成对抗网络（GAN）：生成手写数字

    技术 · 2022-07-21 · 83

    from __future__ import absolute_import from __future__ import division from __future__ import print_function ### import tensorflow as tf import tensorflow.compat.v1 as tf import numpy as np import collections import gzip import os from matplotlib import pyplot as plt from tensorflow.keras.layers import Conv2D,BatchNormalization,Activation,MaxPool2D,Dropout,Flatten,Dense from tensorflow.keras import Model from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.platform import gfile from tensorflow.python.util.deprecation import deprecated _Datasets = collections.namedtuple('_Datasets', ['train', 'validation', 'test']) DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/' np.set_printoptions(threshold=np.inf) class MnistGAN(): def __init__(self): self.mnist=read_data_sets("/opt/application/tensorflow/mooc-2/mnist-gan/MNIST_data/",one_hot=True) self.batch_size=64 self.img_size=self.mnist.train._images[0].shape[0] self.chunk_size=self.mnist.train._num_examples // self.batch_size self.epoch_size=256 self.lr=1e-4 tf.disable_eager_execution() self.real_img=tf.compat.v1.placeholder(tf.float32,[None,self.img_size]) self.fake_img=tf.compat.v1.placeholder(tf.float32,[None,self.img_size]) self.leaky=0.01 self.hideceil=256 self.test=20 def generator(hideceil,img_size,leaky,inputdata,*others): with tf.variable_scope("generator"): layerH=tf.layers.dense(inputdata,hideceil) layerR=tf.maximum(layerH,leaky*layerH) drop =tf.layers.dropout(layerR,rate=0.5) logits=tf.layers.dense(drop,img_size) output=tf.tanh(logits) return logits,output def discriminator(leaky,hideceil,inputdata,reuse=False): with tf.variable_scope("discriminator",reuse=reuse): layer=tf.layers.dense(inputdata,hideceil) relu =tf.maximum(leaky*layer,layer) logits=tf.layers.dense(relu,1) output=tf.sigmoid(logits) return logits,output def 
loss(fake_logits,real_logits): g_loss =tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,labels=tf.ones_like(fake_logits))) d1_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,labels=tf.zeros_like(fake_logits))) d2_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits,labels=tf.ones_like(fake_logits))) d_loss =tf.add(d1_loss,d2_loss) return g_loss,d_loss def optimizer(lr,g_loss,d_loss): train_var=tf.trainable_variables() g_var=[var for var in train_var if "generator" in var.name] d_var=[var for var in train_var if "discriminator" in var.name] g_optimizer=tf.train.AdamOptimizer(lr).minimize(g_loss,var_list=g_var) d_optimizer=tf.train.AdamOptimizer(lr).minimize(d_loss,var_list=d_var) return g_optimizer,d_optimizer def train(self): print(self.hideceil, self.img_size, self.leaky, self.fake_img) gen_logits, gen_outpus = self.generator(self.hideceil, self.img_size, self.leaky, self.fake_img) fake_logits,fake_outpus = self.discriminator(self.leaky, self.hideceil, gen_outpus) real_logits,real_output = self.discriminator(self.leaky, self.hideceil, self.real_img, reuse=True) g_loss,d_loss = self.loss(fake_logits, real_logits) g_opti,dopti = self.optimizer(self.lr, g_loss, d_loss) with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,log_device_placement=False)) as sess: init=tf.global_variables_initializer() sess.run(init) for epoch in range(0,self.epoch_size): for _ in rang(0,self.chunk_size): imgs, _ = self.mnist.train.next_batch(self.batch_size) noise_img = np.random.uniform(-1,1,size=(self.batch_size,self.img_size)) sess.run(d_opti,feed_dict={self.real_img: imgs,self.fake_img: noise_img}) sess.run(g_opti,feed_dict={self.fake_img: noise_img}) gen_loss=sess.run(g_loss,feed_dict={self.fake_img:noise_img}) dis_loss=sess.run(d_loss,feed_dict={self.real_img:imgs,self.fake_img:noise_img}) print("迭代:"+str(epoch)+"g_loss="+str(gen_loss)+"d_loss="+str(dis_loss)) if (epoch % 5 ==0): 
noise_img=np.random.uniform(-1,1,size=(self.test,self.img_size)) samples=sess.run(gen_oupus,feed_dict={self.fake_img:noise_img}) fig,axes=plt.subplots(figsize=(7,7),nrows=5,ncols=5,sharey=True,sharex=True) for ax,img in zip(axes.platten(),samples*(-1)): ax.xaxis.set_visible(False) ax.yaxis.set_visible(False) ax.imshow(img.reshape((28,28)),cmap="Greys_r") plt.show() def _read32(bytestream): dt = np.dtype(np.uint32).newbyteorder('>') return np.frombuffer(bytestream.read(4), dtype=dt)[0] def _extract_images(f): print('Extracting', f.name) with gzip.GzipFile(fileobj=f) as bytestream: magic = _read32(bytestream) if magic != 2051: raise ValueError('Invalid magic number %d in MNIST image file: %s' %(magic, f.name)) num_images = _read32(bytestream) rows = _read32(bytestream) cols = _read32(bytestream) buf = bytestream.read(rows * cols * num_images) data = np.frombuffer(buf, dtype=np.uint8) data = data.reshape(num_images, rows, cols, 1) return data def _dense_to_one_hot(labels_dense, num_classes): num_labels = labels_dense.shape[0] index_offset = np.arange(num_labels) * num_classes labels_one_hot = np.zeros((num_labels, num_classes)) labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 return labels_one_hot def _extract_labels(f, one_hot=False, num_classes=10): print('Extracting', f.name) with gzip.GzipFile(fileobj=f) as bytestream: magic = _read32(bytestream) if magic != 2049: raise ValueError('Invalid magic number %d in MNIST label file: %s' %(magic, f.name)) num_items = _read32(bytestream) buf = bytestream.read(num_items) labels = np.frombuffer(buf, dtype=np.uint8) if one_hot: return _dense_to_one_hot(labels, num_classes) return labels class _DataSet(object): def __init__(self,images,labels,fake_data=False,one_hot=False,dtype=dtypes.float32,reshape=True,seed=None): seed1, seed2 = random_seed.get_seed(seed) np.random.seed(seed1 if seed is None else seed2) dtype = dtypes.as_dtype(dtype).base_dtype if dtype not in (dtypes.uint8, dtypes.float32): raise 
TypeError('Invalid image dtype %r, expected uint8 or float32' %dtype) if fake_data: self._num_examples = 10000 self.one_hot = one_hot else: assert images.shape[0] == labels.shape[0], ('images.shape: %s labels.shape: %s' % (images.shape, labels.shape)) self._num_examples = images.shape[0] # Convert shape from [num examples, rows, columns, depth] # to [num examples, rows*columns] (assuming depth == 1) if reshape: assert images.shape[3] == 1 images = images.reshape(images.shape[0],images.shape[1] * images.shape[2]) if dtype == dtypes.float32: # Convert from [0, 255] -> [0.0, 1.0]. images = images.astype(np.float32) images = np.multiply(images, 1.0 / 255.0) self._images = images self._labels = labels self._epochs_completed = 0 self._index_in_epoch = 0 def images(self): return self._images def labels(self): return self._labels def num_examples(self): return self._num_examples def epochs_completed(self): return self._epochs_completed def next_batch(self, batch_size, fake_data=False, shuffle=True): if fake_data: fake_image = [1] * 784 if self.one_hot: fake_label = [1] + [0] * 9 else: fake_label = 0 return [fake_image for _ in xrange(batch_size)], [fake_label for _ in xrange(batch_size)] start = self._index_in_epoch # Shuffle for the first epoch if self._epochs_completed == 0 and start == 0 and shuffle: perm0 = np.arange(self._num_examples) np.random.shuffle(perm0) self._images = self.images[perm0] self._labels = self.labels[perm0] # Go to the next epoch if start + batch_size > self._num_examples: # Finished epoch self._epochs_completed += 1 # Get the rest examples in this epoch rest_num_examples = self._num_examples - start images_rest_part = self._images[start:self._num_examples] labels_rest_part = self._labels[start:self._num_examples] # Shuffle the data if shuffle: perm = np.arange(self._num_examples) np.random.shuffle(perm) self._images = self.images[perm] self._labels = self.labels[perm] # Start next epoch start = 0 self._index_in_epoch = batch_size - 
rest_num_examples end = self._index_in_epoch images_new_part = self._images[start:end] labels_new_part = self._labels[start:end] return np.concatenate((images_rest_part, images_new_part),axis=0), np.concatenate( (labels_rest_part, labels_new_part), axis=0) else: self._index_in_epoch += batch_size end = self._index_in_epoch return self._images[start:end], self._labels[start:end] def _maybe_download(filename, work_directory, source_url): if not gfile.Exists(work_directory): gfile.MakeDirs(work_directory) filepath = os.path.join(work_directory, filename) if not gfile.Exists(filepath): urllib.request.urlretrieve(source_url, filepath) with gfile.GFile(filepath) as f: size = f.size() print('Successfully downloaded', filename, size, 'bytes.') return filepath def read_data_sets(train_dir,fake_data=False,one_hot=False,dtype=dtypes.float32,reshape=True,validation_size=5000,seed=None,source_u rl=DEFAULT_SOURCE_URL): if fake_data: def fake(): return _DataSet([], [],fake_data=True,one_hot=one_hot,dtype=dtype,seed=seed) train = fake() validation = fake() test = fake() return _Datasets(train=train, validation=validation, test=test) if not source_url: # empty string check source_url = DEFAULT_SOURCE_URL train_images_file = 'train-images-idx3-ubyte.gz' train_labels_file = 'train-labels-idx1-ubyte.gz' test_images_file = 't10k-images-idx3-ubyte.gz' test_labels_file = 't10k-labels-idx1-ubyte.gz' local_file = _maybe_download(train_images_file, train_dir,source_url + train_images_file) with gfile.Open(local_file, 'rb') as f: train_images = _extract_images(f) local_file = _maybe_download(train_labels_file, train_dir,source_url + train_labels_file) with gfile.Open(local_file, 'rb') as f: train_labels = _extract_labels(f, one_hot=one_hot) local_file = _maybe_download(test_images_file, train_dir,source_url + test_images_file) with gfile.Open(local_file, 'rb') as f: test_images = _extract_images(f) local_file = _maybe_download(test_labels_file, train_dir,source_url + test_labels_file) with 
gfile.Open(local_file, 'rb') as f: test_labels = _extract_labels(f, one_hot=one_hot) if not 0 <= validation_size <= len(train_images): raise ValueError('Validation size should be between 0 and {}. Received: {}.'.format(len(train_images), validation_size)) validation_images = train_images[:validation_size] validation_labels = train_labels[:validation_size] train_images = train_images[validation_size:] train_labels = train_labels[validation_size:] options = dict(dtype=dtype, reshape=reshape, seed=seed) train = _DataSet(train_images, train_labels, **options) validation = _DataSet(validation_images, validation_labels, **options) test = _DataSet(test_images, test_labels, **options) return _Datasets(train=train, validation=validation, test=test) if __name__=="__main__": mnist=MnistGAN() mnist.train()
    Processed: 0.008, SQL: 9