How to Implement Automatic Colorization of Black-and-White Images (Code Implementation)

This article implements, in code, the algorithm and network architecture described earlier, then drives the training loop so that the generator network learns to colorize a given grayscale image correctly. First we build the generator and discriminator networks:

import numpy as np
import tensorflow as tf
class Generator(tf.keras.Model):
    def __init__(self, encoder_kernel, decoder_kernel):
        super(Generator, self).__init__()
        self.encoder_kernels = encoder_kernel  # parameters of the convolution (encoder) layers
        self.decoder_kernels = decoder_kernel  # parameters of the transposed-convolution (decoder) layers
        self.kernel_size = 4
        self.output_channels = 3  # the final output is a 3-channel color image
        self.left_size_layers = []
        self.right_size_layers = []
        self.last_layers = []
        self.create_network()
    def create_network(self):  # build the generator network
        for index, kernel in enumerate(self.encoder_kernels):  # convolution layers learn the patterns in the input image
            down_sample_layers = []
            down_sample_layers.append(tf.keras.layers.Conv2D(
                kernel_size = self.kernel_size,
                filters = kernel[0],
                strides = kernel[1],
                padding = 'same'
            ))
            down_sample_layers.append(tf.keras.layers.BatchNormalization())
            down_sample_layers.append(tf.keras.layers.LeakyReLU())
            self.left_size_layers.append(down_sample_layers)
        for index, kernel in enumerate(self.decoder_kernels):  # transposed-convolution layers assign colors to the pixels
            up_sample_layers = []
            up_sample_layers.append(tf.keras.layers.Conv2DTranspose(
                kernel_size = self.kernel_size,
                filters = kernel[0],
                strides = kernel[1],
                padding = 'same'
            ))
            up_sample_layers.append(tf.keras.layers.BatchNormalization())
            up_sample_layers.append(tf.keras.layers.ReLU())
            self.right_size_layers.append(up_sample_layers)
        self.last_layers.append(tf.keras.layers.Conv2D(
                kernel_size = 1,
                filters = self.output_channels,
                strides = 1,
                padding = 'same',
                activation = 'tanh'
            ))  # produce the final color image
    def call(self, x):
        x = tf.convert_to_tensor(x, dtype = tf.float32)
        left_layer_results = []
        for layers in self.left_size_layers:
            for layer in layers:
                x = layer(x)
            left_layer_results.append(x)
        left_layer_results.reverse()
        idx = 0
        x = left_layer_results[idx]
        for layers in self.right_size_layers:
            corresponding_left = left_layer_results[idx + 1]  # skip connection: hand the matching encoder output to this decoder layer
            idx += 1
            for layer in layers:
                x = layer(x)
            x = tf.keras.layers.concatenate([x, corresponding_left])
        for layers in self.last_layers:
            x = layers(x)
        return x
    def create_variables(self):  # build the network's internal weights
        dummy1 = np.zeros((1, 256, 256, 1))
        self.call(dummy1)
class Discriminator(tf.keras.Model):
    def __init__(self, encoder_kernel):
        super(Discriminator, self).__init__()
        self.kernels = encoder_kernel  # parameters of the discriminator's convolution layers
        self.discriminator_layers = []
        self.kernel_size = 4
        self.create_network()
    def create_network(self):
        for index, kernel in enumerate(self.kernels):  # convolution layers learn the patterns in the input image
            self.discriminator_layers.append(tf.keras.layers.Conv2D(
                kernel_size = self.kernel_size,
                filters = kernel[0],
                strides = kernel[1],
                padding = 'same'
            ))
            self.discriminator_layers.append(tf.keras.layers.BatchNormalization())
            self.discriminator_layers.append(tf.keras.layers.LeakyReLU())
        self.discriminator_layers.append(tf.keras.layers.Conv2D(
                kernel_size = 4,
                filters = 1,
                strides = 1,
                padding = 'same'
            ))  # output a score indicating how plausible the colorization is
    def call(self, x):
        x = tf.convert_to_tensor(x, dtype = tf.float32)
        for layer in self.discriminator_layers:
            x = layer(x)
        return x
    def create_variables(self):  # build the discriminator's internal weights
        dummy1 = np.zeros((1, 256, 256, 1))
        dummy2 = np.zeros((1, 256, 256, 3))
        x = np.concatenate((dummy1, dummy2), axis = 3)
        self.call(x)
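
As a quick sanity check, the two classes can be instantiated with the same (filters, stride) configurations that the ColorGAN driver further below uses. This is a minimal sketch added for illustration, not part of the original training flow:

generator = Generator(
    encoder_kernel = [(64, 1), (64, 2), (128, 2), (256, 2),
                      (512, 2), (512, 2), (512, 2), (512, 2)],
    decoder_kernel = [(512, 2), (512, 2), (512, 2), (256, 2),
                      (128, 2), (64, 2), (64, 2)]
)
generator.create_variables()  # builds the weights by running a dummy 256x256 grayscale image
discriminator = Discriminator([(64, 2), (128, 2), (256, 2), (512, 1)])
discriminator.create_variables()  # builds the weights from a dummy grayscale+color input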

Next we need the data-preprocessing code, which covers loading images, converting color images to grayscale, and converting between the LAB and RGB color spaces. To improve the colorization results, the algorithm has the generator produce images in LAB format, because the quality of the result depends heavily on how color lightness is handled.

RGB does not represent lightness as a separate channel, whereas LAB dedicates its L channel to exactly that. Training the generator to output LAB images therefore lets the network model lightness explicitly, which noticeably improves colorization quality. Once a LAB image has been generated, the code converts it back to RGB for display. The relevant implementation follows:

import os
import sys
import time
import random
import pickle
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
def stitch_images(grayscale, original, pred):  # "stitch" together the grayscale image, the corresponding color image, and the generator's colorized result
    gap = 5
    width, height = original[0][:, :, 0].shape
    img_per_row = 2 if width > 200 else 4
    img = Image.new('RGB', (width * img_per_row * 3 + gap * (img_per_row - 1), height * int(len(original) / img_per_row)))
    grayscale = np.array(grayscale).squeeze()
    original = np.array(original)
    pred = np.array(pred)
    for ix in range(len(original)):
        xoffset = int(ix % img_per_row) * width * 3 + int(ix % img_per_row) * gap
        yoffset = int(ix / img_per_row) * height
        im1 = Image.fromarray(grayscale[ix])
        im2 = Image.fromarray(original[ix])
        im3 = Image.fromarray((pred[ix] * 255).astype(np.uint8))
        img.paste(im1, (xoffset, yoffset))
        img.paste(im2, (xoffset + width, yoffset))
        img.paste(im3, (xoffset + width + width, yoffset))
    return img
def imshow(img, title=''):  # display an image
    fig = plt.gcf()
    fig.canvas.manager.set_window_title(title)  # set_window_title lives on canvas.manager in Matplotlib >= 3.4
    plt.axis('off')
    plt.imshow(img, interpolation='none')
    plt.show()
import tensorflow as tf
COLORSPACE_RGB = 'RGB'
COLORSPACE_LAB = 'LAB'  # constants used when converting between the RGB and LAB formats
def preprocess(img, colorspace_in, colorspace_out):
    if colorspace_out.upper() == COLORSPACE_RGB:
        if colorspace_in == COLORSPACE_LAB:
            img = lab_to_rgb(img)


        # [0, 255] => [-1, 1]
        img = (img / 255.0) * 2 - 1


    elif colorspace_out.upper() == COLORSPACE_LAB:
        if colorspace_in == COLORSPACE_RGB:
            img = rgb_to_lab(img / 255.0)
        L_chan, a_chan, b_chan = tf.unstack(img, axis=3)


        # L: [0, 100] => [-1, 1]
        # A, B: [-110, 110] => [-1, 1]
        img = tf.stack([L_chan / 50 - 1, a_chan / 110, b_chan / 110], axis=3)
    return img
def postprocess(img, colorspace_in, colorspace_out):
    if colorspace_in.upper() == COLORSPACE_RGB:
        # [-1, 1] => [0, 1]
        img = (img + 1) / 2


        if colorspace_out == COLORSPACE_LAB:
            img = rgb_to_lab(img)
    elif colorspace_in.upper() == COLORSPACE_LAB:
        L_chan, a_chan, b_chan = tf.unstack(img, axis=3)
        # L: [-1, 1] => [0, 100]
        # A, B: [-1, 1] => [-110, 110]
        img = tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3)


        if colorspace_out == COLORSPACE_RGB:
            img = lab_to_rgb(img)
    return img
def rgb_to_lab(srgb):
 # based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
    with tf.name_scope("rgb_to_lab"):
        srgb_pixels = tf.reshape(srgb, [-1, 3])
        srgb_pixels =  tf.cast(srgb_pixels, tf.float32)
        with tf.name_scope("srgb_to_xyz"):
            linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
            exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
            rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
            rgb_to_xyz = tf.constant([
                #    X        Y          Z
                [0.412453, 0.212671, 0.019334],  # R
                [0.357580, 0.715160, 0.119193],  # G
                [0.180423, 0.072169, 0.950227],  # B
            ])
            xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)


        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("xyz_to_cielab"):


            # normalize for D65 white point
            xyz_normalized_pixels = tf.multiply(xyz_pixels, [1 / 0.950456, 1.0, 1 / 1.088754])


            epsilon = 6 / 29
            linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
            exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
            fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4 / 29) * linear_mask + (xyz_normalized_pixels ** (1 / 3)) * exponential_mask


            # convert to lab
            fxfyfz_to_lab = tf.constant([
                #  l       a       b
                [0.0, 500.0, 0.0],  # fx
                [116.0, -500.0, 200.0],  # fy
                [0.0, 0.0, -200.0],  # fz
            ])
            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])


        return tf.reshape(lab_pixels, tf.shape(srgb))
def lab_to_rgb(lab):
    with tf.name_scope("lab_to_rgb"):
        lab_pixels = tf.reshape(lab, [-1, 3])
        lab_pixels = tf.cast(lab_pixels, tf.float32)
        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("cielab_to_xyz"):
            # convert to fxfyfz
            lab_to_fxfyfz = tf.constant([
                #   fx      fy        fz
                [1 / 116.0, 1 / 116.0, 1 / 116.0],  # l
                [1 / 500.0, 0.0, 0.0],  # a
                [0.0, 0.0, -1 / 200.0],  # b
            ])
            fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)


            # convert to xyz
            epsilon = 6 / 29
            linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
            exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
            xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4 / 29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask


            # denormalize for D65 white point
            xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
        with tf.name_scope("xyz_to_srgb"):
            xyz_to_rgb = tf.constant([
                #     r           g          b
                [3.2404542, -0.9692660, 0.0556434],  # x
                [-1.5371385, 1.8760108, -0.2040259],  # y
                [-0.4985314, 0.0415560, 1.0572252],  # z
            ])
            rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
            # avoid a slightly negative number messing up the conversion
            rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
            linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
            exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
            srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1 / 2.4) * 1.055) - 0.055) * exponential_mask
        return tf.reshape(srgb_pixels, tf.shape(lab))
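
These helpers can be spot-checked with a quick round trip. The sketch below is my own addition for verification, not part of the original flow; it converts a random RGB batch to LAB and back and confirms the result stays close to the input:

import numpy as np
import tensorflow as tf

rgb = tf.constant(np.random.rand(1, 8, 8, 3), dtype=tf.float32)  # RGB values in [0, 1]
lab = rgb_to_lab(rgb)
rgb_back = lab_to_rgb(lab)
print(float(tf.reduce_max(tf.abs(rgb - rgb_back))))  # should be close to 0

# preprocess/postprocess additionally rescale the channels into [-1, 1] for the networks:
# L: [0, 100] -> [-1, 1], a/b: [-110, 110] -> [-1, 1], and back again.
scaled = preprocess(rgb * 255.0, colorspace_in=COLORSPACE_RGB, colorspace_out=COLORSPACE_LAB)
restored = postprocess(scaled, colorspace_in=COLORSPACE_LAB, colorspace_out=COLORSPACE_RGB)
print(float(tf.reduce_max(tf.abs(rgb - restored))))  # should also be small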

The code above is fairly long; it implements helper functionality such as the RGB/LAB conversion. Since this logic is not directly tied to the main algorithm, a rough understanding is enough and readers need not spend much effort on it. Next we load the dataset. This training run uses the Places365 dataset; the following code unpacks and loads the data:


!tar -xvf '/content/drive/Shared drives/chenyi19820904.edu.us/place365_dataset/test-256.tar'  # unpack the dataset

After the command above has unpacked the dataset to disk, the following code loads the data:

import os
import glob
import numpy as np
import tensorflow as tf
from imageio import imread  # scipy.misc.imread was removed in SciPy 1.2+; imageio's imread is a drop-in replacement here
from abc import abstractmethod
PLACES365_DATASET = 'places365'
class BaseDataset():  # reads image data into memory item by item for network training
    def __init__(self, name, path, training=True, augment=True):
        self.name = name
        self.augment = augment and training
        self.training = training
        self.path = path
        self._data = []
    def __len__(self):
        return len(self.data)
    def __iter__(self):
        total = len(self)
        start = 0
        while start < total:
            item = self[start]
            start += 1
            yield item
        # the generator simply ends here; raising StopIteration explicitly inside a generator is an error under PEP 479
    def __getitem__(self, index):
        val = self.data[index]
        try:
            img = imread(val) if isinstance(val, str) else val


            # skip images that are already grayscale (all three channels identical)
            if np.sum(img[:,:,0] - img[:,:,1]) == 0 and np.sum(img[:,:,0] - img[:,:,2]) == 0:
                return None


            if self.augment and np.random.binomial(1, 0.5) == 1:
                img = img[:, ::-1, :]  # data augmentation: random horizontal flip
        except:
            img = None
        return img
    def generator(self, batch_size, recursive=False):
        start = 0
        total = len(self)
        while True:
            while start < total:
                end = np.min([start + batch_size, total])
                items = []


                for ix in range(start, end):
                    item = self[ix]
                    if item is not None:
                        items.append(item)


                start = end
                yield items
            if recursive:
                start = 0  # start over from the beginning when iterating recursively
            else:
                return  # a plain return ends the generator (PEP 479 forbids raising StopIteration)
    @property
    def data(self):
        if len(self._data) == 0:
            self._data = self.load()
            np.random.shuffle(self._data)
        return self._data
    @abstractmethod
    def load(self):
        return []
class Places365Dataset(BaseDataset):
    def __init__(self, path, training=True, augment=True):
        super(Places365Dataset, self).__init__(PLACES365_DATASET, path, training, augment)
    def load(self):  # collect the paths of all jpg images in the dataset folder
        data = glob.glob(self.path + '/*.jpg', recursive=True)
        return data
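
Assuming the unpacked images sit in /content/test_256/ (the path ColorGAN uses below), the dataset class can be exercised on its own. A minimal sketch, added here for illustration:

dataset = Places365Dataset(path='/content/test_256/', training=True, augment=True)
print(len(dataset))                 # number of jpg files found
batch_gen = dataset.generator(16)   # yields lists of up to 16 images
first_batch = next(batch_gen)
print(np.array(first_batch).shape)  # e.g. (16, 256, 256, 3)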

With the data and the networks prepared, the following code drives the training process:

class  ColorGAN:
    def  __init__(self):
        self.generator = None
        self.discriminator = None
        self.global_step = tf.Variable(0, dtype = tf.float32, trainable=False)
        self.create_generator_discriminator()
        self.data_generator = self.create_dataset(True)  # load the training dataset
        self.dataset_val = self.create_dataset(False)
        self.sample_generator = self.dataset_val.generator(8, True)
        self.learning_rate = tf.compat.v1.train.exponential_decay(
            learning_rate = 3e-4, global_step = self.global_step,
            decay_steps = 1e5, decay_rate = 0.1
        )  # the generator's learning rate needs to decay as training proceeds (decay_steps must be a step count such as 1e5, not 1e-5)
        self.generator_optimizer = tf.optimizers.Adam(self.learning_rate, beta_1 = 0)
        self.discriminator_optimizer = tf.optimizers.Adam(3e-5, beta_1 = 0)
        self.batch_size = 16
        self.epochs = 5
        self.epoch = 0
        self.step = 0
        self.run_folder = "/content/drive/My Drive/ColorGAN/models/"
        #self.load_model()  # uncomment to load previously saved network weights directly
    def  create_generator_discriminator(self):  # build the generator and the discriminator
        generator_encoder = [  # first value is the filter count, second is the stride; the kernel size is always 4
            (64, 1),
            (64, 2),
            (128, 2),
            (256, 2),
            (512, 2),
            (512, 2),
            (512, 2),
            (512, 2)
        ]  # convolution-layer (encoder) parameters of the generator
        generator_decoder = [
            (512, 2),
            (512, 2),
            (512, 2),
            (256, 2),
            (128, 2),
            (64, 2),
            (64, 2)
        ]  # transposed-convolution (decoder) parameters of the generator
        self.generator = Generator(generator_encoder, generator_decoder)
        self.generator.create_variables()
        discriminator_encoder = [
            (64, 2),
            (128, 2),
            (256, 2),
            (512, 1)
        ]  # convolution-layer parameters of the discriminator
        self.discriminator = Discriminator(discriminator_encoder)
        self.discriminator.create_variables()
    def  train(self):
        for epoch in range(self.epochs):  # iterate over the training data to train the generator and discriminator
            data_gen = self.data_generator.generator(16)
            for img in data_gen:
                img = np.array(img)
                self.train_discriminator(img)
                self.train_generator(img)
                self.train_generator(img)  # the generator is updated twice for every discriminator update
                self.step += 1
                if  self.step  % 100 == 0:  # periodically show training progress
                    display.clear_output(wait = True)
                    self.sample()
                    self.save_model()
    def  train_discriminator(self, img_color):
        img_gray = tf.image.rgb_to_grayscale(img_color)  # convert the image to grayscale
        img_gray = tf.cast(img_gray, tf.float32)
        lab_color_img = preprocess(img_color, colorspace_in = COLORSPACE_RGB,
                                   colorspace_out = COLORSPACE_LAB)  # convert to LAB format
        gen_img = self.generator(img_gray)
        real_img = tf.concat([img_gray, lab_color_img], 3)
        fake_img = tf.concat([img_gray, gen_img], 3)
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            # the discriminator should output large values for (grayscale, real color) pairs
            # and small values for pairs containing the generator's colorization
            tape.watch(self.discriminator.trainable_variables)
            discriminator_real = self.discriminator(real_img, training = True)
            discriminator_fake = self.discriminator(fake_img, training = True)
            loss_real = tf.nn.sigmoid_cross_entropy_with_logits(logits = discriminator_real,
                                                                labels = tf.ones_like(discriminator_real))
            loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = discriminator_fake,
                                                                labels = tf.zeros_like(discriminator_fake))
            discriminator_loss = tf.reduce_mean(loss_real) + tf.reduce_mean(loss_fake)
        grads = tape.gradient(discriminator_loss, self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_variables))
    def  train_generator(self, img_color):
        img_gray = tf.image.rgb_to_grayscale(img_color)
        img_gray = tf.cast(img_gray, tf.float32)
        lab_color_img = preprocess(img_color, colorspace_in = COLORSPACE_RGB,
                                   colorspace_out = COLORSPACE_LAB)  # convert to LAB format
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            # the generated color image, concatenated with the grayscale input, should score as highly
            # as possible when passed through the discriminator
            tape.watch(self.generator.trainable_variables)
            gen_img = self.generator(tf.cast(img_gray, tf.float32), training = True)
            fake_img = tf.concat([img_gray, gen_img], 3)
            discriminator_fake =  self.discriminator(fake_img, training = True)
            loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = discriminator_fake,
                                                                labels = tf.ones_like(discriminator_fake))  # try to pass the discriminator's scrutiny
            generator_discriminator_loss = tf.reduce_mean(loss_fake)
            generator_content_loss = tf.reduce_mean(tf.abs(lab_color_img - gen_img)) * 100.0  # L1 loss keeps the shapes of objects in the generated image close to those in the input
            generator_loss = generator_discriminator_loss + generator_content_loss
        grads = tape.gradient(generator_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(zip(grads, self.generator.trainable_variables))
    def  sample(self):  # examine the generator's colorization quality
        input_imgs = next(self.sample_generator)
        gray_imgs = tf.image.rgb_to_grayscale(input_imgs)
        gray_imgs = tf.cast(gray_imgs, tf.float32)
        fake_imgs = self.generator(gray_imgs, training = True)
        fake_imgs = postprocess(tf.convert_to_tensor(fake_imgs), colorspace_in = COLORSPACE_LAB,
                                 colorspace_out = COLORSPACE_RGB)
        img_show = stitch_images(gray_imgs, input_imgs, fake_imgs.numpy())  # paste the three images side by side
        imshow(np.array(img_show), "color_gan")
    def  save_model(self):  # save the current network weights
        self.discriminator.save_weights(self.run_folder + "discriminator.h5")
        self.generator.save_weights(self.run_folder + "generator.h5")
    def  load_model(self):  # load network weights
        self.discriminator.load_weights(self.run_folder + "discriminator.h5")
        self.generator.load_weights(self.run_folder + "generator.h5")
    def  create_dataset(self, training):  # create the training/validation dataset
        return Places365Dataset(
            path= '/content/test_256/',
            training=training,
            augment= True)
import os
from IPython import display
gan = ColorGAN()
gan.train()  # start the training process

Running this code starts the network training process. Training is quite time-consuming, so readers can instead load the pretrained network from the book's companion materials to inspect the results directly. After lengthy training, I noticed an interesting phenomenon while reviewing the colorizations: the colorized image sometimes has better aesthetic or artistic qualities than the original color photo. Figure 4 shows colorization results produced by the trained network:



[Figure 4: colorization results produced by the trained network]
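
To reproduce this kind of output from saved weights, the model can be loaded and run on a validation batch. A minimal sketch, assuming the .h5 files written by save_model above exist in run_folder; it mirrors the sample method:

gan = ColorGAN()
gan.load_model()                   # restore the generator and discriminator weights
imgs = next(gan.sample_generator)  # a batch of validation color images
gray = tf.cast(tf.image.rgb_to_grayscale(imgs), tf.float32)
lab = gan.generator(gray)          # predicted LAB image, channels scaled to [-1, 1]
rgb = postprocess(tf.convert_to_tensor(lab),
                  colorspace_in = COLORSPACE_LAB,
                  colorspace_out = COLORSPACE_RGB)
imshow(np.array(stitch_images(gray, imgs, rgb.numpy())), 'color_gan')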
