У меня возникает проблема при вызове `model.fit`.
При попытке обучить модель возникает следующая ошибка:
'''---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
Cell In[61], line 11
7 batch_size = 64
9 model.compile(loss=tf.keras.losses.BinaryCrossentropy(), optimizer=RMSprop(learning_rate=0.001), metrics=['accuracy'])
---> 11 model_history = model.fit(train_dataset,
12 epochs=epochs,
13 #steps_per_epoch=dataset_size // batch_size,
14 validation_data=val_dataset)
File /opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
120 filtered_tb = _process_traceback_frames(e.__traceback__)
121 # To get the full stack trace, call:
122 # `keras.config.disable_traceback_filtering()`
--> 123 raise e.with_traceback(filtered_tb) from None
124 finally:
125 del filtered_tb
File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
51 try:
52 ctx.ensure_initialized()
---> 53 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
54 inputs, attrs, num_outputs)
55 except core._NotOkStatusException as e:
56 if name is not None:
InvalidArgumentError: Graph execution error:
Detected at node IteratorGetNext defined at (most recent call last):
File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
File "/opt/conda/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
File "/opt/conda/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 701, in start
File "/opt/conda/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in dispatch_queue
File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 523, in process_one
File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 429, in dispatch_shell
File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 429, in do_execute
File "/opt/conda/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell
File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell
File "/opt/conda/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async
File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes
File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code
File "/tmp/ipykernel_33/1529170660.py", line 11, in <module>
File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 118, in error_handler
File "/opt/conda/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 323, in fit
File "/opt/conda/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 116, in one_step_on_iterator
Incompatible shapes at component 0: expected [?,768,768,3] but got [64,1,768,768,3].
[[{{node IteratorGetNext}}]] [Op:__inference_one_step_on_iterator_16980]'''
Как я понимаю, главная проблема в этом сообщении: `Incompatible shapes at component 0: expected [?,768,768,3] but got [64,1,768,768,3]. [[{{node IteratorGetNext}}]] [Op:__inference_one_step_on_iterator_16980]`. Но почему моя модель ожидает такую форму?
P.S. У меня модель U-Net:
def unet(input_shape=(768, 768, 3)):
    """Build a U-Net ``tf.keras.Model`` from the encoder/bottleneck/decoder helpers.

    Args:
        input_shape: Shape of ONE input sample, without the batch axis.
            Defaults to ``(768, 768, 3)``, the previously hard-coded value, so
            existing callers get an identical model. Keras prepends the batch
            axis itself, so the built model consumes rank-4 batches
            ``(batch, H, W, C)``.
            NOTE(review): non-default shapes must be compatible with the
            ``encoder``/``decoder`` helpers (defined elsewhere). For a U-Net the
            spatial dims presumably need to be divisible by the total
            down-sampling factor; confirm against the encoder depth.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping ``inputs`` to the decoder
        output.

    NOTE(review): ``model.fit`` failing with
    "expected [?,768,768,3] but got [64,1,768,768,3]" means the dataset
    delivered rank-5 batches (an extra size-1 axis after the batch axis).
    That must be fixed in the input pipeline, for example by removing a
    per-sample ``expand_dims``/``batch(1)`` or squeezing that axis, not by
    changing this model. TODO confirm where the extra axis is introduced.
    """
    # Keras ``Input`` shape excludes the batch dimension, hence the ``?``
    # (unknown batch size) in the expected shape reported by ``fit``.
    inputs = tf.keras.layers.Input(shape=input_shape)

    # ``encoder`` is assumed to return the deepest feature map plus the list of
    # intermediate activations used as skip connections by ``decoder``.
    encoder_output, convs = encoder(inputs)
    bottle_neck = bottleneck(encoder_output)
    outputs = decoder(bottle_neck, convs)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model
Вот данные, которые поступают в модель:
train_dataset.element_spec
(TensorSpec(shape=(None, 768, 768, 3), dtype=tf.float32, name=None),
TensorSpec(shape=(None, 768, 768, 2), dtype=tf.float32, name=None))
val_dataset.element_spec
(TensorSpec(shape=(None, 768, 768, 3), dtype=tf.float32, name=None),
TensorSpec(shape=(None, 768, 768, 2), dtype=tf.float32, name=None))