Tensorflow выгружает датасет в память
Пытаюсь обучить сверточную нейронную сеть, слишком быстро переполняется ОЗУ. Сейчас я отказался от GPU и пытаюсь начать обучение хотя бы на CPU. Сама модель запускается под WSL. У меня 64GB оперативной памяти и 20 из них заберает Win+WSL. При запуске обучение черерез мгновение съедается все оставшиеся 40гб оперативы и SSD грузит на 100% после чего WSL падает.
После небольшого эксперемента я понял что при подаче условно 5-и картинок, похожей проблеммы не наблюдается. Это может значить что выборка грузится не по одной картинке хотя я указал batch_size = 1, и buffer_size = 1. Не понимаю как заставить keras загружать только одну картинку за раз
Модель
model = keras.models.Sequential([
keras.layers.InputLayer(input_shape=(None, None, 1)),
keras.layers.Conv2D(64, (3, 3), padding="same", activation=keras.activations.relu),
keras.layers.Conv2D(64, (3, 3), padding="same", activation=keras.activations.relu, strides=2),
keras.layers.Conv2D(128, (3, 3), padding="same", activation=keras.activations.relu),
keras.layers.Conv2D(128, (3, 3), padding="same", activation=keras.activations.relu, strides=2),
keras.layers.Conv2D(256, (3, 3), padding="same", activation=keras.activations.relu),
keras.layers.Conv2D(256, (3, 3), padding="same", activation=keras.activations.relu, strides=2),
keras.layers.Conv2D(512, (3, 3), padding="same", activation=keras.activations.relu),
keras.layers.Conv2DTranspose(256, (3, 3), strides=2, padding="same", activation=keras.activations.relu),
keras.layers.Conv2D(256, (3, 3), padding="same", activation=keras.activations.relu),
keras.layers.Conv2DTranspose(128, (3, 3), strides=2, padding="same", activation=keras.activations.relu),
keras.layers.Conv2D(128, (3, 3), padding="same", activation=keras.activations.relu),
keras.layers.Conv2DTranspose(64, (3, 3), strides=2, padding="same", activation=keras.activations.relu),
keras.layers.Conv2D(2, (3, 3), padding="same", activation=keras.activations.tanh)
])
model.compile(optimizer='adam', loss='mse')
Подготовка данных и обучение
def get_train_data(image):
x = image[:, :, 0] / 100.0
x = np.expand_dims(x, axis=-1)
y = image[:, :, 1:] / 127.0
return (x, y)
def prepare_photos(photo):
img = io.imread(photo.decode("utf-8"))
to_resize = (np.array(img.shape[0:2]) // 8) * 8
img = keras.preprocessing.image.smart_resize(img, to_resize)
lab_img = color.rgb2lab(img / 255)
return lab_img
def generator(files):
for file in files:
img = prepare_photos(file)
x, y = get_train_data(img)
del img
gc.collect()
yield x, y
folder = "./photos"
batch_size = 1
total_files = 100
files = [os.path.join(folder, file).encode("utf-8") for file in os.listdir(folder)][:total_files]
output_signature = (
tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32),
tf.TensorSpec(shape=(None, None, 2), dtype=tf.float32)
)
dataset = tf.data.Dataset.from_generator(generator,args=[files], output_signature = output_signature)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=1)
model.fit(dataset, epochs=3)