Раздельное обновление весов нейронной сети в двух ветках
Была построена нейронная сеть VAE, но с двумя параллельными ветками. Представляю код кодера:
# Реализаця кодера
encoder_input = layers.Input(shape = input_dim, name="Input_Encoder");
print(encoder_input.shape)
x = tf.keras.layers.Lambda(CalculateFFT, name = 'Calcualte_FFT')(encoder_input)
x_Re = tf.keras.layers.Conv1D(filters = 64,
kernel_size = 2,
activation="tanh",
name = 'Encoder_Conv_1_Re',
padding = 'same')(tf.reshape(x[0], (1,1,1024)))
x_Re = tf.keras.layers.Conv1D(filters = 128,
kernel_size = 2,
activation="tanh",
name = 'Encoder_Conv_2_Re',
padding = 'same')(x_Re)
x_Re = tf.keras.layers.Conv1D(filters = 256,
kernel_size = 2,
activation="tanh",
name = 'Encoder_Conv_3_Re',
padding = 'same')(x_Re)
x_Re = tf.keras.layers.Conv1D(filters = 512,
kernel_size = 2,
activation="tanh",
name = 'Encoder_Conv_4_Re',
padding = 'same')(x_Re)
print(x_Re.shape)
x_Re = tf.keras.layers.Flatten(name = 'Flatten_Re')(x_Re)
print(x_Re.shape)
x_Re = tf.keras.layers.Dense(512)(x_Re)
print(x_Re.shape)
z_mean_Re = layers.Dense(latent_dim, name="z_mean_Re")(x_Re)
z_log_var_Re = layers.Dense(latent_dim, name="z_log_var")(x_Re)
z_Re = Sampling()([z_mean_Re, z_log_var_Re])
x_Im = tf.keras.layers.Conv1D(filters = 64,
kernel_size = 5,
activation="tanh",
name = 'Encoder_Conv_1_Im',
padding = 'same')(tf.reshape(x[1], (1,1,1024)))
x_Im = tf.keras.layers.Conv1D(filters = 128,
kernel_size = 5,
activation="tanh",
name = 'Encoder_Conv_2_Im',
padding = 'same')(x_Im)
x_Im = tf.keras.layers.Conv1D(filters = 256,
kernel_size = 5,
activation="tanh",
name = 'Encoder_Conv_3_Im',
padding = 'same')(x_Im)
x_Im = tf.keras.layers.Conv1D(filters = 512,
kernel_size = 5,
activation="tanh",
name = 'Encoder_Conv_4_Im',
padding = 'same')(x_Im)
print(x_Im.shape)
x_Im = tf.keras.layers.Flatten(name = 'Flatten')(x_Im)
print(x_Im.shape)
x_Im = tf.keras.layers.Dense(512)(x_Im)
print(x_Im.shape)
print(x_Re.shape)
z_mean_Im = layers.Dense(latent_dim, name="z_mean_Im")(x_Im)
z_log_var_Im = layers.Dense(latent_dim, name="z_log_var_Im")(x_Im)
z_Im = Sampling()([z_mean_Im, z_log_var_Im])
encoder = keras.Model(encoder_input, [z_mean_Re, z_log_var_Re, z_Re, z_mean_Im, z_log_var_Im, z_Im], name="encoder")
encoder.summary()
plot_model(encoder, to_file='model_plot.png', show_shapes=True, show_layer_names=True)
Представляю код декодера:
decoder_input_Re = tf.keras.Input(shape=(latent_dim))
decoder_input_Im = tf.keras.Input(shape=(latent_dim))
x_Re = tf.keras.layers.Dense(int(latent_dim), activation="tanh")(decoder_input_Re)
print('Re: ',x_Re.shape)
x_Re = layers.Reshape((1,1,int(latent_dim)))(x_Re)
print('Re: ',x_Re.shape)
x_Re = tf.keras.layers.Conv2D(512,1, activation="tanh")(x_Re)
print('Re: ',x_Re.shape)
x_Re = tf.keras.layers.Conv2D(1024,1, activation="tanh")(x_Re)
print('Re: ',x_Re.shape)
x_Re = layers.Reshape((1,int(latent_dim * num_compress)))(x_Re)
print('Re: ',x_Re.shape)
x_Re = tf.keras.layers.Flatten(name = 'Flatten_Re')(x_Re)
print('Re: ',x_Re.shape)
print('--------------------------------')
x_Im = tf.keras.layers.Dense(int(latent_dim), activation="tanh")(decoder_input_Im)
print('Im: ',x_Im.shape)
x_Im = layers.Reshape((1,1,int(latent_dim)))(x_Im)
print('Im: ',x_Im.shape)
x_Im = tf.keras.layers.Conv2D(512,1, activation="tanh")(x_Im)
print('Im: ',x_Im.shape)
x_Im = tf.keras.layers.Conv2D(1024,1, activation="tanh")(x_Im)
print('Im: ',x_Im.shape)
x_Im = layers.Reshape((1,int(latent_dim * num_compress)))(x_Im)
print('Im: ',x_Im.shape)
x_Im = tf.keras.layers.Flatten(name = 'Flatten_Im')(x_Im)
print('Im: ',x_Im.shape)
print('--------------------------------')
print('Out')
decoder_outputs = tf.keras.layers.Lambda(RecoverSignal, name = 'RecoverSignal')([[x_Re],[x_Im]])
print(decoder_outputs.shape)
decoder_outputs = decoder_outputs[0]
print(decoder_outputs.shape)
decoder = keras.Model([decoder_input_Re,decoder_input_Im], decoder_outputs, name="decoder")
decoder.summary()
plot_model(decoder, to_file='deocder.png', show_shapes=True, show_layer_names=True)
Как видно, у меня имеется две параллельных ветки, для действительной и мнимой части сигнала. В общем вопрос заключается в том, как производить обновление весов при обучении раздельно между веткой обучения для мнимой и действительной части. Прикладываю стандартный код обучения для VAE:
class VAE(keras.Model):
def __init__(self, encoder, decoder, **kwargs):
super().__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
# Общая ошибка складывающаяся из ошибки востановления и ошибки КЛ
self.total_loss_tracker_Re = keras.metrics.Mean(name="total_loss_Re")
self.total_loss_tracker_Im = keras.metrics.Mean(name="total_loss_Im")
# Ошибка востановления
self.reconstruction_loss_tracker_Re = keras.metrics.Mean(name="reconstruction_loss_Re")
self.reconstruction_loss_tracker_Im = keras.metrics.Mean(name="reconstruction_loss_Im")
# Ошибка КЛ
self.kl_loss_tracker_Re = keras.metrics.Mean(name="kl_loss_Re")
self.kl_loss_tracker_Im = keras.metrics.Mean(name="kl_loss_Im")
self.beta = tf.Variable(0.000000)
# Отслеживание метрик
@property
def metrics(self):
return [
self.total_loss_tracker_Re,
self.total_loss_tracker_Im,
self.reconstruction_loss_tracker_Re,
self.reconstruction_loss_tracker_Im,
self.kl_loss_tracker_Re,
self.kl_loss_tracker_Im,
]
def set_beta(self, new_beta):
self.beta.assign(new_beta)
# Обучение нейронной сети
@tf.function
def train_step(self, data):
with tf.GradientTape(persistent=True) as tape:
z_mean_Re, z_log_var_Re, z_Re, z_mean_Im, z_log_var_Im, z_Im, = self.encoder(data)
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
reconstruction_Re = self.decoder(z)
reconstruction_loss = tf.reduce_mean(
tf.reduce_sum(
tf.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
)
)
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
# kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
kl_loss = (tf.reduce_sum(kl_loss, axis=1))
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
total_loss = tf.cond(self.beta <= 0,
lambda: (reconstruction_loss + kl_loss),
lambda:(reconstruction_loss + self.beta * kl_loss))
# total_loss = reconstruction_loss + kl_loss
# tf.print(beta)
gradients_of_enc = tape.gradient(total_loss, self.encoder.trainable_variables)
gradients_of_dec = tape.gradient(total_loss, self.decoder.trainable_variables)
self.optimizer.apply_gradients(zip(gradients_of_enc, self.encoder.trainable_variables))
self.optimizer.apply_gradients(zip(gradients_of_dec, self.decoder.trainable_variables))
# grads = tape.gradient(total_loss, self.trainable_weights)
# self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
# tf.print(tf.cast(self.beta, tf.float32))
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}
