from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
import numpy as np
import collections
import gzip
import os
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D,BatchNormalization,Activation,MaxPool2D,Dropout,Flatten,Dense
from tensorflow.keras import Model
from six.moves import urllib
from six.moves import xrange
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.platform import gfile
from tensorflow.python.util.deprecation import deprecated
_Datasets = collections.namedtuple('_Datasets', ['train', 'validation', 'test'])
DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
np.set_printoptions(threshold=np.inf)
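# A minimal fully connected GAN on MNIST, written against the TF1 graph API
# (tensorflow.compat.v1). The helpers further down are a local copy of the
# deprecated TensorFlow MNIST read_data_sets utilities.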
class MnistGAN():
    def __init__(self):
        self.mnist = read_data_sets("/opt/application/tensorflow/mooc-2/mnist-gan/MNIST_data/", one_hot=True)
        self.batch_size = 64
        self.img_size = self.mnist.train._images[0].shape[0]
        self.chunk_size = self.mnist.train._num_examples // self.batch_size
        self.epoch_size = 256
        self.lr = 1e-4
        tf.disable_eager_execution()
        self.real_img = tf.compat.v1.placeholder(tf.float32, [None, self.img_size])
        self.fake_img = tf.compat.v1.placeholder(tf.float32, [None, self.img_size])
        self.leaky = 0.01
        self.hideceil = 256
        self.test = 20
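    # Generator: noise vector -> hidden dense layer (leaky ReLU + dropout) -> flattened 28x28 image (tanh).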
    def generator(self, hideceil, img_size, leaky, inputdata):
        with tf.variable_scope("generator"):
            layerH = tf.layers.dense(inputdata, hideceil)
            layerR = tf.maximum(layerH, leaky * layerH)
            drop = tf.layers.dropout(layerR, rate=0.5)
            logits = tf.layers.dense(drop, img_size)
            output = tf.tanh(logits)
            return logits, output
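    # Discriminator: hidden dense layer (leaky ReLU) -> single real/fake logit per image.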
    def discriminator(self, leaky, hideceil, inputdata, reuse=False):
        with tf.variable_scope("discriminator", reuse=reuse):
            layer = tf.layers.dense(inputdata, hideceil)
            relu = tf.maximum(leaky * layer, layer)
            logits = tf.layers.dense(relu, 1)
            output = tf.sigmoid(logits)
            return logits, output
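    # Sigmoid cross-entropy GAN losses: the generator wants fakes scored as real (1);
    # the discriminator wants fakes at 0 and real images at 1.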
    def loss(self, fake_logits, real_logits):
        g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.ones_like(fake_logits)))
        d1_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))
        d2_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_logits)))
        d_loss = tf.add(d1_loss, d2_loss)
        return g_loss, d_loss
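    # One Adam optimizer per network, each restricted to the variables of its own scope.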
    def optimizer(self, lr, g_loss, d_loss):
        train_var = tf.trainable_variables()
        g_var = [var for var in train_var if "generator" in var.name]
        d_var = [var for var in train_var if "discriminator" in var.name]
        g_optimizer = tf.train.AdamOptimizer(lr).minimize(g_loss, var_list=g_var)
        d_optimizer = tf.train.AdamOptimizer(lr).minimize(d_loss, var_list=d_var)
        return g_optimizer, d_optimizer
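    # Build the graph, alternate discriminator/generator updates each batch,
    # and plot a grid of generated samples every 5 epochs.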
    def train(self):
        print(self.hideceil, self.img_size, self.leaky, self.fake_img)
        gen_logits, gen_outputs = self.generator(self.hideceil, self.img_size, self.leaky, self.fake_img)
        fake_logits, fake_outputs = self.discriminator(self.leaky, self.hideceil, gen_outputs)
        real_logits, real_outputs = self.discriminator(self.leaky, self.hideceil, self.real_img, reuse=True)
        g_loss, d_loss = self.loss(fake_logits, real_logits)
        g_opti, d_opti = self.optimizer(self.lr, g_loss, d_loss)
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            for epoch in range(0, self.epoch_size):
                for _ in range(0, self.chunk_size):
                    imgs, _ = self.mnist.train.next_batch(self.batch_size)
                    noise_img = np.random.uniform(-1, 1, size=(self.batch_size, self.img_size))
                    sess.run(d_opti, feed_dict={self.real_img: imgs, self.fake_img: noise_img})
                    sess.run(g_opti, feed_dict={self.fake_img: noise_img})
                gen_loss = sess.run(g_loss, feed_dict={self.fake_img: noise_img})
                dis_loss = sess.run(d_loss, feed_dict={self.real_img: imgs, self.fake_img: noise_img})
                print("epoch=" + str(epoch) + " g_loss=" + str(gen_loss) + " d_loss=" + str(dis_loss))
                if epoch % 5 == 0:
                    noise_img = np.random.uniform(-1, 1, size=(self.test, self.img_size))
                    samples = sess.run(gen_outputs, feed_dict={self.fake_img: noise_img})
                    # 4x5 grid matches the self.test = 20 generated samples.
                    fig, axes = plt.subplots(figsize=(7, 7), nrows=4, ncols=5, sharey=True, sharex=True)
                    for ax, img in zip(axes.flatten(), samples * (-1)):
                        ax.xaxis.set_visible(False)
                        ax.yaxis.set_visible(False)
                        ax.imshow(img.reshape((28, 28)), cmap="Greys_r")
                    plt.show()
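# MNIST loading utilities: parse the gzipped IDX files downloaded from DEFAULT_SOURCE_URL.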
def _read32(bytestream):
    dt = np.dtype(np.uint32).newbyteorder('>')
    return np.frombuffer(bytestream.read(4), dtype=dt)[0]
def _extract_images(f):
    print('Extracting', f.name)
    with gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        buf = bytestream.read(rows * cols * num_images)
        data = np.frombuffer(buf, dtype=np.uint8)
        data = data.reshape(num_images, rows, cols, 1)
        return data
def _dense_to_one_hot(labels_dense, num_classes):
    num_labels = labels_dense.shape[0]
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot
def _extract_labels(f, one_hot=False, num_classes=10):
    print('Extracting', f.name)
    with gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, f.name))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = np.frombuffer(buf, dtype=np.uint8)
        if one_hot:
            return _dense_to_one_hot(labels, num_classes)
        return labels
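# In-memory dataset wrapper with epoch-aware mini-batching via next_batch().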
class _DataSet(object):
    def __init__(self, images, labels, fake_data=False, one_hot=False, dtype=dtypes.float32, reshape=True, seed=None):
        seed1, seed2 = random_seed.get_seed(seed)
        np.random.seed(seed1 if seed is None else seed2)
        dtype = dtypes.as_dtype(dtype).base_dtype
        if dtype not in (dtypes.uint8, dtypes.float32):
            raise TypeError('Invalid image dtype %r, expected uint8 or float32' % dtype)
        if fake_data:
            self._num_examples = 10000
            self.one_hot = one_hot
        else:
            assert images.shape[0] == labels.shape[0], ('images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
            self._num_examples = images.shape[0]
            if reshape:
                assert images.shape[3] == 1
                images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])
            if dtype == dtypes.float32:
                images = images.astype(np.float32)
                images = np.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed
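    # Return the next `batch_size` examples, reshuffling when an epoch ends.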
    def next_batch(self, batch_size, fake_data=False, shuffle=True):
        if fake_data:
            fake_image = [1] * 784
            if self.one_hot:
                fake_label = [1] + [0] * 9
            else:
                fake_label = 0
            return [fake_image for _ in xrange(batch_size)], [fake_label for _ in xrange(batch_size)]
        start = self._index_in_epoch
        if self._epochs_completed == 0 and start == 0 and shuffle:
            perm0 = np.arange(self._num_examples)
            np.random.shuffle(perm0)
            self._images = self.images[perm0]
            self._labels = self.labels[perm0]
        if start + batch_size > self._num_examples:
            self._epochs_completed += 1
            rest_num_examples = self._num_examples - start
            images_rest_part = self._images[start:self._num_examples]
            labels_rest_part = self._labels[start:self._num_examples]
            if shuffle:
                perm = np.arange(self._num_examples)
                np.random.shuffle(perm)
                self._images = self.images[perm]
                self._labels = self.labels[perm]
            start = 0
            self._index_in_epoch = batch_size - rest_num_examples
            end = self._index_in_epoch
            images_new_part = self._images[start:end]
            labels_new_part = self._labels[start:end]
            return np.concatenate((images_rest_part, images_new_part), axis=0), np.concatenate(
                (labels_rest_part, labels_new_part), axis=0)
        else:
            self._index_in_epoch += batch_size
            end = self._index_in_epoch
            return self._images[start:end], self._labels[start:end]
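# Download a file into work_directory unless it already exists locally.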
def _maybe_download(filename, work_directory, source_url):
    if not gfile.Exists(work_directory):
        gfile.MakeDirs(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not gfile.Exists(filepath):
        urllib.request.urlretrieve(source_url, filepath)
        with gfile.GFile(filepath) as f:
            size = f.size()
        print('Successfully downloaded', filename, size, 'bytes.')
    return filepath
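# Download and parse the four MNIST files, split off a validation set, and
# return (train, validation, test) wrapped in _Datasets.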
def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=dtypes.float32, reshape=True,
                   validation_size=5000, seed=None, source_url=DEFAULT_SOURCE_URL):
    if fake_data:
        def fake():
            return _DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)
        train = fake()
        validation = fake()
        test = fake()
        return _Datasets(train=train, validation=validation, test=test)
    if not source_url:
        source_url = DEFAULT_SOURCE_URL
    train_images_file = 'train-images-idx3-ubyte.gz'
    train_labels_file = 'train-labels-idx1-ubyte.gz'
    test_images_file = 't10k-images-idx3-ubyte.gz'
    test_labels_file = 't10k-labels-idx1-ubyte.gz'
    local_file = _maybe_download(train_images_file, train_dir, source_url + train_images_file)
    with gfile.Open(local_file, 'rb') as f:
        train_images = _extract_images(f)
    local_file = _maybe_download(train_labels_file, train_dir, source_url + train_labels_file)
    with gfile.Open(local_file, 'rb') as f:
        train_labels = _extract_labels(f, one_hot=one_hot)
    local_file = _maybe_download(test_images_file, train_dir, source_url + test_images_file)
    with gfile.Open(local_file, 'rb') as f:
        test_images = _extract_images(f)
    local_file = _maybe_download(test_labels_file, train_dir, source_url + test_labels_file)
    with gfile.Open(local_file, 'rb') as f:
        test_labels = _extract_labels(f, one_hot=one_hot)
    if not 0 <= validation_size <= len(train_images):
        raise ValueError('Validation size should be between 0 and {}. Received: {}.'.format(len(train_images), validation_size))
    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]
    options = dict(dtype=dtype, reshape=reshape, seed=seed)
    train = _DataSet(train_images, train_labels, **options)
    validation = _DataSet(validation_images, validation_labels, **options)
    test = _DataSet(test_images, test_labels, **options)
    return _Datasets(train=train, validation=validation, test=test)
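# Entry point: build the GAN and start training.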
if __name__ == "__main__":
    mnist = MnistGAN()
    mnist.train()