Tensorflow выгружает датасет в память

Пытаюсь обучить сверточную нейронную сеть, слишком быстро переполняется ОЗУ. Сейчас я отказался от GPU и пытаюсь начать обучение хотя бы на CPU. Сама модель запускается под WSL. У меня 64GB оперативной памяти и 20 из них заберает Win+WSL. При запуске обучение черерез мгновение съедается все оставшиеся 40гб оперативы и SSD грузит на 100% после чего WSL падает.

После небольшого эксперемента я понял что при подаче условно 5-и картинок, похожей проблеммы не наблюдается. Это может значить что выборка грузится не по одной картинке хотя я указал batch_size = 1, и buffer_size = 1. Не понимаю как заставить keras загружать только одну картинку за раз

Модель

model = keras.models.Sequential([
    keras.layers.InputLayer(input_shape=(None, None, 1)),
    keras.layers.Conv2D(64, (3, 3), padding="same", activation=keras.activations.relu),
    keras.layers.Conv2D(64, (3, 3), padding="same", activation=keras.activations.relu, strides=2),
    keras.layers.Conv2D(128, (3, 3), padding="same", activation=keras.activations.relu),
    keras.layers.Conv2D(128, (3, 3), padding="same", activation=keras.activations.relu, strides=2),
    keras.layers.Conv2D(256, (3, 3), padding="same", activation=keras.activations.relu),
    keras.layers.Conv2D(256, (3, 3), padding="same", activation=keras.activations.relu, strides=2),
    keras.layers.Conv2D(512, (3, 3), padding="same", activation=keras.activations.relu),
    keras.layers.Conv2DTranspose(256, (3, 3), strides=2, padding="same", activation=keras.activations.relu),
    keras.layers.Conv2D(256, (3, 3), padding="same", activation=keras.activations.relu),
    keras.layers.Conv2DTranspose(128, (3, 3), strides=2, padding="same", activation=keras.activations.relu),
    keras.layers.Conv2D(128, (3, 3), padding="same", activation=keras.activations.relu),
    keras.layers.Conv2DTranspose(64, (3, 3), strides=2, padding="same", activation=keras.activations.relu),
    keras.layers.Conv2D(2, (3, 3), padding="same", activation=keras.activations.tanh)
])

model.compile(optimizer='adam', loss='mse')

Подготовка данных и обучение

def get_train_data(image):
    x = image[:, :, 0] / 100.0
    x = np.expand_dims(x, axis=-1)
    y = image[:, :, 1:] / 127.0
    return (x, y)

def prepare_photos(photo):
    img = io.imread(photo.decode("utf-8"))       
    to_resize = (np.array(img.shape[0:2]) // 8) * 8
    img = keras.preprocessing.image.smart_resize(img, to_resize)
    lab_img = color.rgb2lab(img / 255)
    return lab_img

def generator(files):
    for file in files:
        img = prepare_photos(file)
        x, y = get_train_data(img)
        del img
        gc.collect()
        yield x, y
        
folder = "./photos"
batch_size = 1
total_files = 100

files = [os.path.join(folder, file).encode("utf-8") for file in os.listdir(folder)][:total_files]

output_signature = (
    tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32),
    tf.TensorSpec(shape=(None, None, 2), dtype=tf.float32)
)

dataset = tf.data.Dataset.from_generator(generator,args=[files], output_signature = output_signature)

dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=1)

model.fit(dataset, epochs=3)

Ответы (0 шт):