This section implements in code the algorithm and network architecture described earlier, and drives the training loop so that the generator network learns to correctly colorize a given grayscale image. First, we construct the generator and discriminator networks:
class Generator(tf.keras.Model):
    def __init__(self, encoder_kernel, decoder_kernel):
        super(Generator, self).__init__()
        self.encoder_kernels = encoder_kernel  # parameters for the convolution layers
        self.decoder_kernels = decoder_kernel  # parameters for the transposed-convolution layers
        self.kernel_size = 4
        self.output_channels = 3  # the final output is a 3-channel color image
        self.left_size_layers = []
        self.right_size_layers = []
        self.last_layers = []
        self.create_network()
    def create_network(self):  # build the generator network
        for index, kernel in enumerate(self.encoder_kernels):  # convolution layers that learn the structure of the input image
            down_sample_layers = []
            down_sample_layers.append(tf.keras.layers.Conv2D(
                kernel_size = self.kernel_size,
                filters = kernel[0],
                strides = kernel[1],
                padding = 'same'
            ))
            down_sample_layers.append(tf.keras.layers.BatchNormalization())
            down_sample_layers.append(tf.keras.layers.LeakyReLU())
            self.left_size_layers.append(down_sample_layers)
        for index, kernel in enumerate(self.decoder_kernels):  # transposed-convolution layers that assign a color to each pixel
            up_sample_layers = []
            up_sample_layers.append(tf.keras.layers.Conv2DTranspose(
                kernel_size = self.kernel_size,
                filters = kernel[0],
                strides = kernel[1],
                padding = 'same'
            ))
            up_sample_layers.append(tf.keras.layers.BatchNormalization())
            up_sample_layers.append(tf.keras.layers.ReLU())
            self.right_size_layers.append(up_sample_layers)
        self.last_layers.append(tf.keras.layers.Conv2D(
            kernel_size = 1,
            filters = self.output_channels,
            strides = 1,
            padding = 'same',
            activation = 'tanh'
        ))  # produce the final color image
    def call(self, x):
        x = tf.convert_to_tensor(x, dtype = tf.float32)
        left_layer_results = []
        for layers in self.left_size_layers:
            for layer in layers:
                x = layer(x)
            left_layer_results.append(x)
        left_layer_results.reverse()
        idx = 0
        x = left_layer_results[idx]
        for layers in self.right_size_layers:
            corresponding_left = left_layer_results[idx + 1]  # skip connection: hand the matching encoder output to this decoder block
            idx += 1
            for layer in layers:
                x = layer(x)
            x = tf.keras.layers.concatenate([x, corresponding_left])
        for layer in self.last_layers:
            x = layer(x)
        return x
    def create_variables(self):  # build the generator's internal variables
        dummy1 = np.zeros((1, 256, 256, 1))
        self.call(dummy1)
class Discriminator(tf.keras.Model):
    def __init__(self, encoder_kernel):
        super(Discriminator, self).__init__()
        self.kernels = encoder_kernel  # parameters for the discriminator's convolution layers
        self.discriminator_layers = []
        self.kernel_size = 4
        self.create_network()
    def create_network(self):
        for index, kernel in enumerate(self.kernels):  # convolution layers that learn the structure of the input image
            self.discriminator_layers.append(tf.keras.layers.Conv2D(
                kernel_size = self.kernel_size,
                filters = kernel[0],
                strides = kernel[1],
                padding = 'same'
            ))
            self.discriminator_layers.append(tf.keras.layers.BatchNormalization())
            self.discriminator_layers.append(tf.keras.layers.LeakyReLU())
        self.discriminator_layers.append(tf.keras.layers.Conv2D(
            kernel_size = 4,
            filters = 1,
            strides = 1,
            padding = 'same'
        ))  # output a score indicating how plausible the coloring is
    def call(self, x):
        x = tf.convert_to_tensor(x, dtype = tf.float32)
        for layer in self.discriminator_layers:
            x = layer(x)
        return x
    def create_variables(self):  # create the discriminator's internal variables
        dummy1 = np.zeros((1, 256, 256, 1))
        dummy2 = np.zeros((1, 256, 256, 3))
        x = np.concatenate((dummy1, dummy2), axis = 3)
        self.call(x)
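Before moving on, it is worth a quick sanity check that the two networks wire together as intended. The snippet below is an illustrative check of my own, not part of the original listing; it uses the same kernel configurations that the ColorGAN class sets up later in this section and assumes numpy and tensorflow have been imported:
import numpy as np
import tensorflow as tf

# (filters, stride) pairs; the kernel size is fixed at 4
generator_encoder = [(64, 1), (64, 2), (128, 2), (256, 2), (512, 2), (512, 2), (512, 2), (512, 2)]
generator_decoder = [(512, 2), (512, 2), (512, 2), (256, 2), (128, 2), (64, 2), (64, 2)]
discriminator_encoder = [(64, 2), (128, 2), (256, 2), (512, 1)]

g = Generator(generator_encoder, generator_decoder)
g.create_variables()
fake = g(np.zeros((1, 256, 256, 1)))  # grayscale in, 3-channel color out
print(fake.shape)  # expected: (1, 256, 256, 3)

d = Discriminator(discriminator_encoder)
d.create_variables()
score = d(np.concatenate([np.zeros((1, 256, 256, 1)), fake.numpy()], axis=3))
print(score.shape)  # a grid of per-patch realism scores: (1, 32, 32, 1)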
Next we need the data-preprocessing code, which covers loading images, converting color images to grayscale, and converting between the LAB and RGB color spaces. To improve colorization quality, the algorithm has the generator produce images in the LAB color space, because perceived image quality depends heavily on luminance.
RGB has no explicit luminance channel, whereas LAB does (its L channel), so training the generator to output LAB images lets the network control luminance directly, which noticeably improves the colorization results. Once a LAB image has been produced, the code converts it back to RGB for display. The relevant implementation follows:
import os
import sys
import time
import random
import pickle
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
def stitch_images(grayscale, original, pred):  # "stitch" the grayscale input, the original color image, and the generator's colorized result into one sheet
    gap = 5
    width, height = original[0][:, :, 0].shape  # images are assumed square (256x256), so the axis order does not matter here
    img_per_row = 2 if width > 200 else 4
    img = Image.new('RGB', (width * img_per_row * 3 + gap * (img_per_row - 1), height * int(len(original) / img_per_row)))
    grayscale = np.array(grayscale).squeeze()
    original = np.array(original)
    pred = np.array(pred)
    for ix in range(len(original)):
        xoffset = int(ix % img_per_row) * width * 3 + int(ix % img_per_row) * gap
        yoffset = int(ix / img_per_row) * height
        im1 = Image.fromarray(grayscale[ix].astype(np.uint8))  # cast to uint8 so PIL treats the values as pixel intensities
        im2 = Image.fromarray(original[ix])
        im3 = Image.fromarray((pred[ix] * 255).astype(np.uint8))
        img.paste(im1, (xoffset, yoffset))
        img.paste(im2, (xoffset + width, yoffset))
        img.paste(im3, (xoffset + width + width, yoffset))
    return img
def imshow(img, title=''):  # display an image
    fig = plt.gcf()
    fig.canvas.manager.set_window_title(title)  # fig.canvas.set_window_title was removed in newer Matplotlib releases
    plt.axis('off')
    plt.imshow(img, interpolation='none')
    plt.show()
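To see what these helpers produce, here is a short demo; it is illustrative only, with random arrays standing in for real image batches:
import numpy as np

# four fake samples: grayscale and color images in [0, 255], predictions in [0, 1]
grays = np.random.randint(0, 256, (4, 256, 256, 1)).astype(np.uint8)
colors = np.random.randint(0, 256, (4, 256, 256, 3)).astype(np.uint8)
preds = np.random.rand(4, 256, 256, 3)

sheet = stitch_images(grays, colors, preds)  # a 2-per-row sheet, since width 256 > 200
imshow(np.array(sheet), 'stitch_images demo')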
import numpy as np
import tensorflow as tf
COLORSPACE_RGB = 'RGB'
COLORSPACE_LAB = 'LAB'  # conversion between the RGB and LAB color spaces
def preprocess(img, colorspace_in, colorspace_out):
    if colorspace_out.upper() == COLORSPACE_RGB:
        if colorspace_in == COLORSPACE_LAB:
            img = lab_to_rgb(img)
        # [0, 255] => [-1, 1]
        img = (img / 255.0) * 2 - 1
    elif colorspace_out.upper() == COLORSPACE_LAB:
        if colorspace_in == COLORSPACE_RGB:
            img = rgb_to_lab(img / 255.0)
        L_chan, a_chan, b_chan = tf.unstack(img, axis=3)
        # L: [0, 100] => [-1, 1]
        # A, B: [-110, 110] => [-1, 1]
        img = tf.stack([L_chan / 50 - 1, a_chan / 110, b_chan / 110], axis=3)
    return img
def postprocess(img, colorspace_in, colorspace_out):
    if colorspace_in.upper() == COLORSPACE_RGB:
        # [-1, 1] => [0, 1]
        img = (img + 1) / 2
        if colorspace_out == COLORSPACE_LAB:
            img = rgb_to_lab(img)
    elif colorspace_in.upper() == COLORSPACE_LAB:
        L_chan, a_chan, b_chan = tf.unstack(img, axis=3)
        # L: [-1, 1] => [0, 100]
        # A, B: [-1, 1] => [-110, 110]
        img = tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3)
        if colorspace_out == COLORSPACE_RGB:
            img = lab_to_rgb(img)
    return img
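As a concrete check on the normalization above, consider a single pure-white RGB pixel: its LAB coordinates are L=100, a≈0, b≈0, which preprocess scales to roughly (1, 0, 0). The snippet is purely illustrative (note it relies on rgb_to_lab, defined next):
import numpy as np

white = np.full((1, 1, 1, 3), 255.0)  # one white RGB pixel in batch form
lab_norm = preprocess(white, colorspace_in=COLORSPACE_RGB, colorspace_out=COLORSPACE_LAB)
print(lab_norm.numpy())  # approximately [[[[1., 0., 0.]]]] since L/50 - 1 = 1 and a, b ≈ 0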
def rgb_to_lab(srgb):
    # based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
    with tf.name_scope("rgb_to_lab"):
        srgb_pixels = tf.reshape(srgb, [-1, 3])
        srgb_pixels = tf.cast(srgb_pixels, tf.float32)
        with tf.name_scope("srgb_to_xyz"):
            linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
            exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
            rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
            rgb_to_xyz = tf.constant([
                #    X         Y          Z
                [0.412453, 0.212671, 0.019334],  # R
                [0.357580, 0.715160, 0.119193],  # G
                [0.180423, 0.072169, 0.950227],  # B
            ])
            xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("xyz_to_cielab"):
            # normalize for D65 white point
            xyz_normalized_pixels = tf.multiply(xyz_pixels, [1 / 0.950456, 1.0, 1 / 1.088754])
            epsilon = 6 / 29
            linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
            exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
            fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4 / 29) * linear_mask + (xyz_normalized_pixels ** (1 / 3)) * exponential_mask
            # convert to lab
            fxfyfz_to_lab = tf.constant([
                #   l       a        b
                [0.0,    500.0,    0.0],  # fx
                [116.0, -500.0,  200.0],  # fy
                [0.0,      0.0, -200.0],  # fz
            ])
            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
        return tf.reshape(lab_pixels, tf.shape(srgb))
def lab_to_rgb(lab):
    with tf.name_scope("lab_to_rgb"):
        lab_pixels = tf.reshape(lab, [-1, 3])
        lab_pixels = tf.cast(lab_pixels, tf.float32)
        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("cielab_to_xyz"):
            # convert to fxfyfz
            lab_to_fxfyfz = tf.constant([
                #     fx         fy         fz
                [1 / 116.0, 1 / 116.0, 1 / 116.0],  # l
                [1 / 500.0,       0.0,       0.0],  # a
                [0.0,             0.0, -1 / 200.0],  # b
            ])
            fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)
            # convert to xyz
            epsilon = 6 / 29
            linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
            exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
            xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4 / 29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask
            # denormalize for D65 white point
            xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
        with tf.name_scope("xyz_to_srgb"):
            xyz_to_rgb = tf.constant([
                #     r           g          b
                [3.2404542, -0.9692660,  0.0556434],  # x
                [-1.5371385, 1.8760108, -0.2040259],  # y
                [-0.4985314, 0.0415560,  1.0572252],  # z
            ])
            rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
            # avoid a slightly negative number messing up the conversion
            rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
            linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
            exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
            srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1 / 2.4) * 1.055) - 0.055) * exponential_mask
        return tf.reshape(srgb_pixels, tf.shape(lab))
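These color-space conversions are easy to get subtly wrong, so a round-trip test is a cheap safeguard. The following is a suggested check of my own, not part of the original code; it converts a random RGB image to LAB and back and prints the maximum reconstruction error:
import tensorflow as tf

rgb = tf.random.uniform((1, 64, 64, 3))  # a random sRGB image with values in [0, 1]
recovered = lab_to_rgb(rgb_to_lab(rgb))  # RGB -> LAB -> RGB
print(tf.reduce_max(tf.abs(rgb - recovered)).numpy())  # should be near 0, up to float error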
This code is fairly long; it implements helper functionality such as the RGB/LAB conversions. Since its logic is not directly tied to the main algorithm, a rough understanding is sufficient and readers need not dwell on it. Next we load the dataset. Training uses the Places365 dataset; the following command extracts it:
!tar -xvf '/content/drive/Shared drives/chenyi19820904.edu.us/place365_dataset/test-256.tar'  # extract the dataset
After extracting the dataset to disk with the command above, load the data with the following code:
import os
import glob
import numpy as np
import tensorflow as tf
from imageio import imread  # scipy.misc.imread has been removed from SciPy; imageio provides a drop-in replacement
from abc import abstractmethod
PLACES365_DATASET = 'places365'
class BaseDataset():  # read image data into memory one item at a time for network training
    def __init__(self, name, path, training=True, augment=True):
        self.name = name
        self.augment = augment and training
        self.training = training
        self.path = path
        self._data = []
    def __len__(self):
        return len(self.data)
    def __iter__(self):
        total = len(self)
        start = 0
        while start < total:
            item = self[start]
            start += 1
            yield item
        # falling off the end stops the iteration; raising StopIteration inside
        # a generator is an error in Python 3.7+ (PEP 479)
    def __getitem__(self, index):
        val = self.data[index]
        try:
            img = imread(val) if isinstance(val, str) else val
            # skip images that are already grayscale
            if np.sum(img[:, :, 0] - img[:, :, 1]) == 0 and np.sum(img[:, :, 0] - img[:, :, 2]) == 0:
                return None
            if self.augment and np.random.binomial(1, 0.5) == 1:
                img = img[:, ::-1, :]  # random horizontal flip for data augmentation
        except:
            img = None
        return img
    def generator(self, batch_size, recursive=False):
        start = 0
        total = len(self)
        while True:
            while start < total:
                end = np.min([start + batch_size, total])
                items = []
                for ix in range(start, end):
                    item = self[ix]
                    if item is not None:
                        items.append(item)
                start = end
                yield items
            if recursive:
                start = 0
            else:
                return  # PEP 479: return, rather than raise StopIteration, to end the generator
    @property
    def data(self):
        if len(self._data) == 0:
            self._data = self.load()
            np.random.shuffle(self._data)
        return self._data
    @abstractmethod
    def load(self):
        return []
class Places365Dataset(BaseDataset):
    def __init__(self, path, training=True, augment=True):
        super(Places365Dataset, self).__init__(PLACES365_DATASET, path, training, augment)
    def load(self):  # collect the image file paths
        data = glob.glob(self.path + '/*.jpg', recursive=True)
        return data
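With the dataset class in place, pulling a batch looks like the sketch below. The path matches the one ColorGAN uses later; substitute the directory where your extracted images actually live. A batch can contain fewer than batch_size items because grayscale or unreadable images are filtered out:
dataset = Places365Dataset(path='/content/test_256/', training=True, augment=True)
print(len(dataset))  # number of .jpg files found

batch_gen = dataset.generator(16)  # yields lists of up to 16 images
first_batch = next(batch_gen)
print(len(first_batch), first_batch[0].shape)  # e.g. 16 (256, 256, 3)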
With the data and the networks ready, the following code drives the training process:
class ColorGAN:
    def __init__(self):
        self.generator = None
        self.discriminator = None
        self.global_step = tf.Variable(0, dtype = tf.float32, trainable=False)
        self.create_generator_discriminator()
        self.data_generator = self.create_dataset(True)  # load the training dataset
        self.dataset_val = self.create_dataset(False)
        self.sample_generator = self.dataset_val.generator(8, True)
        self.learning_rate = tf.compat.v1.train.exponential_decay(
            learning_rate = 3e-4, global_step = self.global_step,
            decay_steps = 1e5, decay_rate = 0.1
        )  # the generator's learning rate decays during training; decay_steps is a step count (1e5 assumed here)
        self.generator_optimizer = tf.optimizers.Adam(self.learning_rate, beta_1 = 0)
        self.discriminator_optimizer = tf.optimizers.Adam(3e-5, beta_1 = 0)
        self.batch_size = 16
        self.epochs = 5
        self.epoch = 0
        self.step = 0
        self.run_folder = "/content/drive/My Drive/ColorGAN/models/"
        #self.load_model()  # uncomment to load previously saved weights instead of training from scratch
    def create_generator_discriminator(self):  # build the generator and the discriminator
        generator_encoder = [  # the first value is the filter count, the second the stride; the kernel size is always 4
            (64, 1),
            (64, 2),
            (128, 2),
            (256, 2),
            (512, 2),
            (512, 2),
            (512, 2),
            (512, 2)
        ]  # generator convolution layer parameters
        generator_decoder = [
            (512, 2),
            (512, 2),
            (512, 2),
            (256, 2),
            (128, 2),
            (64, 2),
            (64, 2)
        ]  # generator transposed-convolution layer parameters
        self.generator = Generator(generator_encoder, generator_decoder)
        self.generator.create_variables()
        discriminator_encoder = [
            (64, 2),
            (128, 2),
            (256, 2),
            (512, 1)
        ]  # discriminator convolution layer parameters
        self.discriminator = Discriminator(discriminator_encoder)
        self.discriminator.create_variables()
    def train(self):
        for epoch in range(self.epochs):  # feed training data to the generator and discriminator
            data_gen = self.data_generator.generator(self.batch_size)
            for img in data_gen:
                img = np.array(img)
                self.train_discriminator(img)
                self.train_generator(img)  # train the generator twice per discriminator step to keep the two in balance
                self.train_generator(img)
                self.step += 1
                self.global_step.assign_add(1)  # advance the counter so the learning-rate schedule actually decays
                if self.step % 100 == 0:  # show intermediate results
                    display.clear_output(wait = True)
                    self.sample()
                    self.save_model()
    def train_discriminator(self, img_color):
        img_gray = tf.image.rgb_to_grayscale(img_color)  # convert the image to grayscale
        img_gray = tf.cast(img_gray, tf.float32)
        lab_color_img = preprocess(img_color, colorspace_in = COLORSPACE_RGB,
                                   colorspace_out = COLORSPACE_LAB)  # convert to LAB
        gen_img = self.generator(img_gray)
        real_img = tf.concat([img_gray, lab_color_img], 3)
        fake_img = tf.concat([img_gray, gen_img], 3)
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            # push the output up for (grayscale, real color) pairs and down for
            # (grayscale, generated color) pairs
            tape.watch(self.discriminator.trainable_variables)
            discriminator_real = self.discriminator(real_img, training = True)
            discriminator_fake = self.discriminator(fake_img, training = True)
            loss_real = tf.nn.sigmoid_cross_entropy_with_logits(logits = discriminator_real,
                                                                labels = tf.ones_like(discriminator_real))
            loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = discriminator_fake,
                                                                labels = tf.zeros_like(discriminator_fake))
            discriminator_loss = tf.reduce_mean(loss_real) + tf.reduce_mean(loss_fake)
        grads = tape.gradient(discriminator_loss, self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_variables))
    def train_generator(self, img_color):
        img_gray = tf.image.rgb_to_grayscale(img_color)
        img_gray = tf.cast(img_gray, tf.float32)
        lab_color_img = preprocess(img_color, colorspace_in = COLORSPACE_RGB,
                                   colorspace_out = COLORSPACE_LAB)  # convert to LAB
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            # make the discriminator's output on (grayscale, generated color) pairs as large as possible
            tape.watch(self.generator.trainable_variables)
            gen_img = self.generator(img_gray, training = True)
            fake_img = tf.concat([img_gray, gen_img], 3)
            discriminator_fake = self.discriminator(fake_img, training = True)
            loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = discriminator_fake,
                                                                labels = tf.ones_like(discriminator_fake))  # try to pass the discriminator's scrutiny
            generator_discriminator_loss = tf.reduce_mean(loss_fake)
            generator_content_loss = tf.reduce_mean(tf.abs(lab_color_img - gen_img)) * 100.0  # keep object shapes in the generated image consistent with the input
            generator_loss = generator_discriminator_loss + generator_content_loss
        grads = tape.gradient(generator_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(zip(grads, self.generator.trainable_variables))
    def sample(self):  # check the generator's colorization quality
        input_imgs = next(self.sample_generator)
        gray_imgs = tf.image.rgb_to_grayscale(input_imgs)
        gray_imgs = tf.cast(gray_imgs, tf.float32)
        fake_imgs = self.generator(gray_imgs, training = True)  # keep BatchNorm in the same mode as during training
        fake_imgs = postprocess(tf.convert_to_tensor(fake_imgs), colorspace_in = COLORSPACE_LAB,
                                colorspace_out = COLORSPACE_RGB)
        img_show = stitch_images(gray_imgs, input_imgs, fake_imgs.numpy())  # paste the three images side by side
        imshow(np.array(img_show), "color_gan")
    def save_model(self):  # save the current network weights
        self.discriminator.save_weights(self.run_folder + "discriminator.h5")
        self.generator.save_weights(self.run_folder + "generator.h5")
    def load_model(self):  # load network weights
        self.discriminator.load_weights(self.run_folder + "discriminator.h5")
        self.generator.load_weights(self.run_folder + "generator.h5")
    def create_dataset(self, training):  # create the dataset (training or validation)
        return Places365Dataset(
            path = '/content/test_256/',
            training = training,
            augment = True)
import os
from IPython import display
gan = ColorGAN()
gan.train()  # start the training loop
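Once training has finished, or saved weights have been loaded via load_model, colorizing a new photograph needs only the generator. The following minimal inference sketch assumes a grayscale image at the hypothetical path input_gray.jpg; it mirrors the sample method above, including calling the generator with training=True:
from PIL import Image
import numpy as np
import tensorflow as tf

img = Image.open('input_gray.jpg').convert('L').resize((256, 256))  # hypothetical input file
x = np.array(img, dtype=np.float32).reshape(1, 256, 256, 1)  # the generator expects [0, 255] grayscale

lab = gan.generator(x, training=True)  # normalized LAB output in [-1, 1]
rgb = postprocess(tf.convert_to_tensor(lab), colorspace_in=COLORSPACE_LAB,
                  colorspace_out=COLORSPACE_RGB)  # back to RGB in [0, 1]
Image.fromarray((rgb.numpy()[0] * 255).astype(np.uint8)).save('colorized.jpg')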
Calling gan.train() starts the network training loop. Training is quite time-consuming, so readers can load the network I have already trained from the book's companion materials to inspect the results directly. After extended training, I noticed an interesting phenomenon while reviewing the output: the colorized image is sometimes more aesthetically pleasing, even more artistic, than the original color photograph. Figure 4 shows colorization results produced by the trained network: