Улучшение результатов автоэнкодера на базе mnist

Хочу написать нейросеть на keras, которая дорисовывает нижнюю часть изображения, зная верхнюю. Датасет - цифры mnist, 28x28, ЧБ. Я понимаю, что задача это сложная для нейросети и ошибка в любом случае будет довольно высокой, но мне не нужно, чтобы всё было идеально, главное не смазанно, как в итоге у меня и получилось. Привожу код обучения:

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Flatten, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from glob import glob
from PIL import Image
import keras
import numpy as np

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.concatenate([x_train, x_test])
y_train = np.concatenate([y_train, y_test])
x_train = x_train[y_train == 1]

# Размеры изображения
img_width, img_height = 28, 28
input_img = Input(shape=(img_width//2, img_height, 1))  

# Энкодер
x = Conv2D(64, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
encoded = Flatten()(x)

# Декодер
x = Reshape(target_shape=(7, 14, 8*8))(encoded)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)  
# Создание модели
autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer=Adam(learning_rate=0.001), loss='binary_crossentropy')

x_train = x_train.astype('float32') / 255.0;

x_train = np.reshape(x_train, (x_train.shape[0], 28, 28, 1))

autoencoder.fit(x_train[:, :14], x_train[:, 14:], epochs=10, batch_size=128, shuffle=True)
autoencoder.save("model.h5")

Обучаю модель только на единицах, чтоб не путалась и потому что единица самая простая цифра. Но даже так результаты не очень впечатляющие:
Верхняя часть нормальная, а нижняя - тусклая и смазанная введите сюда описание изображения введите сюда описание изображения введите сюда описание изображения

В случае других цифр результат вообще плачевный. Буду благодарен всем, кто подскажет, какую архитектуру и гиперпараметры подобрать, а может, в чём-то другом дело... Можно было бы искусственно увеличивать яркость нижней части изображения, потому что она всегда тусклее получается, но какая-то кривость всё равно присутствует. Из того, что пробовал - шаг обучения меньше (не работает), архитектура посложнее (частично сработало), брать все цифры (результат хуже).


Ответы (1 шт):

Автор решения: gord1402

У вас входные данные имеют размерность 14 * 28 * 1 = 392, тогда как закодированный вектор 6272, то есть в нём больше информации чем в начальном изображении, но автоэнкодер для чёткого результата должен сжать, а не увеличить.

Добавим dense дабы уменьшить закодированный вектор.

x = Conv2D(64, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = Flatten()(x)
encoded = Dense(units=64, activation='relu')(x)

# Декодер
x = Dense(units=x.shape[1], activation='relu')(encoded)
x = Reshape(target_shape=(7, 14, 8*8))(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)  

Обучим на 10 эпохах.

Дабы убрать размытие можно также возвести результат в степень и нормировать, вот код для отображения:

num_images = 10
size = 6

plt.figure(figsize=(size, size / 2 * num_images))

for i in range(num_images):
    # Выбираем случайно
    idx = np.random.choice(len(x_train[y_train == (i % 10)]), 1, replace=False)[0]
    data = x_train[y_train == (i % 10)][idx]
    
    # Убераем размытие
    bottom_half = autoencoder(np.array([data[:14]]))[0] ** 2.5
    bottom_half /= np.max(bottom_half)  
    
    # Объединяем половинки
    combined_image = np.vstack((data[:14], bottom_half))
    
    plt.subplot(num_images, 2, 2 * i + 1)
    plt.imshow(data, cmap='gray')
    plt.title(f"Original")
    plt.axis('off')
    
    plt.subplot(num_images, 2, 2 * i + 2)
    plt.imshow(combined_image, cmap='gray')
    plt.title(f"Reconstructed")
    plt.axis('off')

plt.tight_layout()
plt.show()

Итог:

введите сюда описание изображения

→ Ссылка