Входной слой и веса имеют разную размерность
Что нужно изменить размерность у входного слоя и весов которыми он связан? Если я не изменяю размерность x в цикле, то нейросеть выдаёт вектор размером (28, 10) на выходе
import tensorflow as tf
from keras.datasets import mnist
class NN(tf.Module):
def __init__(self):
super().__init__()
self.fl_init = True
def __call__(self, x):
if self.fl_init:
input_neurons = 784
neurons_layer1 = 16
neurons_layer2 = 16
neurons_layer3 = 10
self.w0 = tf.Variable(tf.random.truncated_normal((input_neurons, neurons_layer1), dtype=tf.float32))
self.w1 = tf.Variable(tf.random.truncated_normal((neurons_layer1, neurons_layer2), dtype=tf.float32))
self.w2 = tf.Variable(tf.random.truncated_normal((neurons_layer2, neurons_layer3), dtype=tf.float32))
self.b0 = tf.Variable(tf.zeros([neurons_layer1], dtype=tf.float32))
self.b1 = tf.Variable(tf.zeros([neurons_layer2], dtype=tf.float32))
self.b2 = tf.Variable(tf.zeros([neurons_layer3], dtype=tf.float32))
self.fl_init = False
x = tf.cast(x, tf.float32)
print(x.shape) # (784)
print(self.w0.shape) # (784, 16)
layer = x @ self.w0 + self.b0 # ошибка
layer = layer @ self.w1 + self.b1
layer = layer @ self.w2 + self.b2
print(layer.shape)
return tf.nn.softmax(layer)
cross_entropy = lambda y_true, y_pred: tf.reduce_mean(tf.losses.categorical_crossentropy(y_true, y_pred))
model = NN()
opt = tf.optimizers.Adam(learning_rate=0.0001)
EPOCHS = 1000
BATCH_SIZE = 32
(x_train, y_train), (__, ___) = mnist.load_data()
x_train = x_train / 255
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(BATCH_SIZE)
for n in range(EPOCHS):
for x_batch, y_batch in train_dataset:
loss_sum = 0
for x, y in zip(x_batch, y_batch):
y_array = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
y_array[y] = 1
y = y_array
x = tf.reshape(x, 784)
with tf.GradientTape() as tape:
f_loss = cross_entropy(y, model(x))
loss_sum += f_loss
grads = tape.gradient(f_loss, model.trainable_variables)
opt.apply_gradients(zip(grads, model.trainable_variables))
print(loss_sum)