как визуализировать выход модели

Мой 'z_sample' должен иметь форму (1, 10) согласно модели, но он приобретает форму (1, 2) и выдает следующую ошибку:

ValueError: Input 0 of layer "dense_1" is incompatible with the layer: expected axis -1of input shape to have value 10, but received input with shape (1, 2)

Если я правильно понимаю, ошибка кроется здесь:

for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
      z_sample = np.array([[xi, yi]])
      x_decoded = vae_decoder(z_sample)

Подскажите пожалуйста, как правильно записать, чтоб соответствовать модели? Привожу функцию визуализации для ссылки:

def plot_latent_space(n=10, figsize=15):
  digit_size = 28
  scale = 1.5
  figure = np.zeros((digit_size * n, digit_size * n))

  grid_x = np.linspace(-scale, scale, n)
  grid_y = np.linspace(-scale, scale, n)[::-1]

  for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
      z_sample = np.array([[xi, yi]])
      x_decoded = vae_decoder(z_sample)
      digit = tf.reshape(x_decoded[0], shape=(digit_size, digit_size))
      figure[
             i * digit_size : (i + 1) * digit_size,
             j * digit_size : (j + 1) * digit_size,
      ] = digit

  plt.figure(figsize=(figsize, figsize))
  start_range = digit_size // 2
  end_range = n * digit_size + start_range
  pixel_range = np.arange(start_range, end_range, digit_size)
  sample_range_x = np.round(grid_x, 1)
  sample_range_y = np.round(grid_y, 1)
  plt.xticks(pixel_range, sample_range_x)
  plt.yticks(pixel_range, sample_range_y)
  plt.xlabel("z[0]")
  plt.ylabel("z[1]")
  plt.imshow(figure, cmap="Greys_r")
  plt.show()

Прилагаю вывод модели:

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 28, 28, 6)    156         ['input_1[0][0]']                
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 14, 14, 6)    0           ['conv2d[0][0]']                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 10, 10, 16)   2416        ['max_pooling2d[0][0]']          
                                                                                                  
 max_pooling2d_1 (MaxPooling2D)  (None, 5, 5, 16)    0           ['conv2d_1[0][0]']               
                                                                                                  
 flatten (Flatten)              (None, 400)          0           ['max_pooling2d_1[0][0]']        
                                                                                                  
 dense (Dense)                  (None, 20)           8020        ['flatten[0][0]']                
                                                                                                  
 tf.split (TFOpLambda)          [(None, 10),         0           ['dense[0][0]']                  
                                 (None, 10)]                                                      
                                                                                                  
 tf.math.multiply (TFOpLambda)  (None, 10)           0           ['tf.split[0][1]']               
                                                                                                  
 tf.compat.v1.shape (TFOpLambda  (2,)                0           ['tf.split[0][0]']               
 )                                                                                                
                                                                                                  
 tf.math.exp (TFOpLambda)       (None, 10)           0           ['tf.math.multiply[0][0]']       
                                                                                                  
 tf.random.normal (TFOpLambda)  (None, 10)           0           ['tf.compat.v1.shape[0][0]']     
                                                                                                  
 tf.math.multiply_1 (TFOpLambda  (None, 10)          0           ['tf.math.exp[0][0]',            
 )                                                                'tf.random.normal[0][0]']       
                                                                                                  
 tf.__operators__.add (TFOpLamb  (None, 10)          0           ['tf.split[0][0]',               
 da)                                                              'tf.math.multiply_1[0][0]']     
                                                                                                  
 dense_1 (Dense)                (None, 400)          4400        ['tf.__operators__.add[0][0]']   
                                                                                                  
 reshape (Reshape)              (None, 5, 5, 16)     0           ['dense_1[0][0]']                
                                                                                                  
 up_sampling2d (UpSampling2D)   multiple             0           ['reshape[0][0]',                
                                                                  'conv2d_transpose[0][0]']       
                                                                                                  
 conv2d_transpose (Conv2DTransp  (None, 14, 14, 16)  6416        ['up_sampling2d[0][0]']          
 ose)                                                                                             
                                                                                                  
 conv2d_transpose_1 (Conv2DTran  (None, 28, 28, 6)   2406        ['up_sampling2d[1][0]']          
 spose)                                                                                           
                                                                                                  
 conv2d_transpose_2 (Conv2DTran  (None, 28, 28, 1)   55          ['conv2d_transpose_1[0][0]']     
 spose)                                                                                           
                                                                                                  
==================================================================================================
Total params: 23,869
Trainable params: 23,869
Non-trainable params: 0
__________________________________________________________________________________________________

Ответы (0 шт):