Раздельное обновление весов нейронной сети в двух ветках

Была построена нейронная сеть VAE, но с двумя параллельными ветками. Представляю код кодера:

# Реализаця кодера
encoder_input = layers.Input(shape = input_dim, name="Input_Encoder");
print(encoder_input.shape)
x = tf.keras.layers.Lambda(CalculateFFT, name = 'Calcualte_FFT')(encoder_input)
x_Re = tf.keras.layers.Conv1D(filters = 64, 
                              kernel_size = 2, 
                              activation="tanh", 
                              name = 'Encoder_Conv_1_Re', 
                              padding = 'same')(tf.reshape(x[0], (1,1,1024)))
x_Re = tf.keras.layers.Conv1D(filters = 128, 
                              kernel_size = 2, 
                              activation="tanh", 
                              name = 'Encoder_Conv_2_Re', 
                              padding = 'same')(x_Re)
x_Re = tf.keras.layers.Conv1D(filters = 256, 
                              kernel_size = 2, 
                              activation="tanh", 
                              name = 'Encoder_Conv_3_Re', 
                              padding = 'same')(x_Re)
x_Re = tf.keras.layers.Conv1D(filters = 512, 
                              kernel_size = 2, 
                              activation="tanh", 
                              name = 'Encoder_Conv_4_Re', 
                              padding = 'same')(x_Re)
print(x_Re.shape)
x_Re = tf.keras.layers.Flatten(name = 'Flatten_Re')(x_Re)
print(x_Re.shape)
x_Re = tf.keras.layers.Dense(512)(x_Re)
print(x_Re.shape)
z_mean_Re = layers.Dense(latent_dim, name="z_mean_Re")(x_Re)

z_log_var_Re = layers.Dense(latent_dim, name="z_log_var")(x_Re)

z_Re = Sampling()([z_mean_Re, z_log_var_Re])

x_Im = tf.keras.layers.Conv1D(filters = 64,
                              kernel_size = 5, 
                              activation="tanh", 
                              name = 'Encoder_Conv_1_Im', 
                              padding = 'same')(tf.reshape(x[1], (1,1,1024)))

x_Im = tf.keras.layers.Conv1D(filters = 128,
                              kernel_size = 5, 
                              activation="tanh", 
                              name = 'Encoder_Conv_2_Im', 
                              padding = 'same')(x_Im)

x_Im = tf.keras.layers.Conv1D(filters = 256,
                              kernel_size = 5, 
                              activation="tanh", 
                              name = 'Encoder_Conv_3_Im', 
                              padding = 'same')(x_Im)

x_Im = tf.keras.layers.Conv1D(filters = 512,
                              kernel_size = 5, 
                              activation="tanh", 
                              name = 'Encoder_Conv_4_Im', 
                              padding = 'same')(x_Im)

print(x_Im.shape)
x_Im = tf.keras.layers.Flatten(name = 'Flatten')(x_Im)
print(x_Im.shape)
x_Im = tf.keras.layers.Dense(512)(x_Im)
print(x_Im.shape)
print(x_Re.shape)
z_mean_Im = layers.Dense(latent_dim, name="z_mean_Im")(x_Im)

z_log_var_Im = layers.Dense(latent_dim, name="z_log_var_Im")(x_Im)

z_Im = Sampling()([z_mean_Im, z_log_var_Im])

encoder = keras.Model(encoder_input, [z_mean_Re, z_log_var_Re, z_Re, z_mean_Im, z_log_var_Im, z_Im], name="encoder")
encoder.summary()
plot_model(encoder, to_file='model_plot.png', show_shapes=True, show_layer_names=True)

Представляю код декодера:

decoder_input_Re = tf.keras.Input(shape=(latent_dim))
decoder_input_Im = tf.keras.Input(shape=(latent_dim))
x_Re =  tf.keras.layers.Dense(int(latent_dim), activation="tanh")(decoder_input_Re)
print('Re: ',x_Re.shape)
x_Re = layers.Reshape((1,1,int(latent_dim)))(x_Re)
print('Re: ',x_Re.shape)
x_Re = tf.keras.layers.Conv2D(512,1, activation="tanh")(x_Re)
print('Re: ',x_Re.shape)
x_Re = tf.keras.layers.Conv2D(1024,1, activation="tanh")(x_Re)
print('Re: ',x_Re.shape)
x_Re = layers.Reshape((1,int(latent_dim * num_compress)))(x_Re)
print('Re: ',x_Re.shape)
x_Re = tf.keras.layers.Flatten(name = 'Flatten_Re')(x_Re)
print('Re: ',x_Re.shape)
print('--------------------------------')
x_Im =  tf.keras.layers.Dense(int(latent_dim), activation="tanh")(decoder_input_Im)
print('Im: ',x_Im.shape)
x_Im = layers.Reshape((1,1,int(latent_dim)))(x_Im)
print('Im: ',x_Im.shape)
x_Im = tf.keras.layers.Conv2D(512,1, activation="tanh")(x_Im)
print('Im: ',x_Im.shape)
x_Im = tf.keras.layers.Conv2D(1024,1, activation="tanh")(x_Im)
print('Im: ',x_Im.shape)
x_Im = layers.Reshape((1,int(latent_dim * num_compress)))(x_Im)
print('Im: ',x_Im.shape)
x_Im = tf.keras.layers.Flatten(name = 'Flatten_Im')(x_Im)
print('Im: ',x_Im.shape)
print('--------------------------------')
print('Out')
decoder_outputs = tf.keras.layers.Lambda(RecoverSignal, name = 'RecoverSignal')([[x_Re],[x_Im]])
print(decoder_outputs.shape)
decoder_outputs = decoder_outputs[0]
print(decoder_outputs.shape)
decoder = keras.Model([decoder_input_Re,decoder_input_Im], decoder_outputs, name="decoder")
decoder.summary()
plot_model(decoder, to_file='deocder.png', show_shapes=True, show_layer_names=True)

Как видно, у меня имеется две параллельных ветки, для действительной и мнимой части сигнала. В общем вопрос заключается в том, как производить обновление весов при обучении раздельно между веткой обучения для мнимой и действительной части. Прикладываю стандартный код обучения для VAE:

class VAE(keras.Model):
    
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        
#       Общая ошибка складывающаяся из ошибки востановления и ошибки КЛ
        self.total_loss_tracker_Re = keras.metrics.Mean(name="total_loss_Re")
        self.total_loss_tracker_Im = keras.metrics.Mean(name="total_loss_Im")
        
#       Ошибка востановления
        self.reconstruction_loss_tracker_Re = keras.metrics.Mean(name="reconstruction_loss_Re")
        self.reconstruction_loss_tracker_Im = keras.metrics.Mean(name="reconstruction_loss_Im")
        
#       Ошибка КЛ
        self.kl_loss_tracker_Re = keras.metrics.Mean(name="kl_loss_Re")
        self.kl_loss_tracker_Im = keras.metrics.Mean(name="kl_loss_Im")      
        
        
        self.beta = tf.Variable(0.000000)
        
#   Отслеживание метрик
    @property
    def metrics(self):
        return [
            self.total_loss_tracker_Re,
            self.total_loss_tracker_Im,
            
            self.reconstruction_loss_tracker_Re,
            self.reconstruction_loss_tracker_Im,
            
            self.kl_loss_tracker_Re,
            self.kl_loss_tracker_Im,
        ]
    
    def set_beta(self, new_beta):
        self.beta.assign(new_beta)   
    
#   Обучение нейронной сети
    @tf.function
    def train_step(self, data):
        
        with tf.GradientTape(persistent=True) as tape:
            z_mean_Re, z_log_var_Re, z_Re, z_mean_Im, z_log_var_Im, z_Im, = self.encoder(data)
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
            reconstruction_Re = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )

            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            
#             kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            kl_loss = (tf.reduce_sum(kl_loss, axis=1))
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
            total_loss = tf.cond(self.beta <= 0, 
                                lambda: (reconstruction_loss + kl_loss),
                                lambda:(reconstruction_loss + self.beta * kl_loss))



#             total_loss = reconstruction_loss + kl_loss
#             tf.print(beta)
            
        
        gradients_of_enc = tape.gradient(total_loss, self.encoder.trainable_variables)
        gradients_of_dec = tape.gradient(total_loss, self.decoder.trainable_variables)
        
        
        
        
        self.optimizer.apply_gradients(zip(gradients_of_enc, self.encoder.trainable_variables))
        self.optimizer.apply_gradients(zip(gradients_of_dec, self.decoder.trainable_variables))
        
        
        
#         grads = tape.gradient(total_loss, self.trainable_weights)
#         self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        
        self.total_loss_tracker.update_state(total_loss)
            
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
            
        self.kl_loss_tracker.update_state(kl_loss)  
#         tf.print(tf.cast(self.beta, tf.float32))
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

введите сюда описание изображения


Ответы (0 шт):