I have a problem with model.fit

When I try to train the model, I get the following error:

'''---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
Cell In[61], line 11
      7 batch_size = 64
      9 model.compile(loss=tf.keras.losses.BinaryCrossentropy(), optimizer=RMSprop(learning_rate=0.001), metrics=['accuracy'])
---> 11 model_history = model.fit(train_dataset,
     12                           epochs=epochs,
     13                           #steps_per_epoch=dataset_size // batch_size,
     14                           validation_data=val_dataset)

File /opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    120     filtered_tb = _process_traceback_frames(e.__traceback__)
    121     # To get the full stack trace, call:
    122     # `keras.config.disable_traceback_filtering()`
--> 123     raise e.with_traceback(filtered_tb) from None
    124 finally:
    125     del filtered_tb

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51 try:
     52   ctx.ensure_initialized()
---> 53   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     54                                       inputs, attrs, num_outputs)
     55 except core._NotOkStatusException as e:
     56   if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node IteratorGetNext defined at (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code

  File "/opt/conda/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>

  File "/opt/conda/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 701, in start

  File "/opt/conda/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start

  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in dispatch_queue

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 523, in process_one

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 429, in dispatch_shell

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 767, in execute_request

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 429, in do_execute

  File "/opt/conda/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes

  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "/tmp/ipykernel_33/1529170660.py", line 11, in <module>

  File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 118, in error_handler

  File "/opt/conda/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 323, in fit

  File "/opt/conda/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 116, in one_step_on_iterator

Incompatible shapes at component 0: expected [?,768,768,3] but got [64,1,768,768,3].
     [[{{node IteratorGetNext}}]] [Op:__inference_one_step_on_iterator_16980]'''

As I understand it, the main problem is: Incompatible shapes at component 0: expected [?,768,768,3] but got [64,1,768,768,3]. [[{{node IteratorGetNext}}]] [Op:__inference_one_step_on_iterator_16980]. But why does my model expect this shape?

P.S. My model is a U-Net:

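import tensorflow as tf

def unet():
    # Input(shape=...) excludes the batch axis, so the model expects batches of shape (batch, 768, 768, 3)
    inputs = tf.keras.layers.Input(shape=(768, 768, 3))
    # encoder, bottleneck and decoder are defined elsewhere (not shown in the question)
    encoder_output, convs = encoder(inputs)

    bottle_neck = bottleneck(encoder_output)

    outputs = decoder(bottle_neck, convs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    return model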
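To check what the model itself accepts, one can inspect its input tensor. This is a minimal sketch: `tiny_model` keeps the same `Input(shape=(768, 768, 3))` as `unet()`, but its body is a hypothetical one-layer stand-in (with 2 output channels to match the masks), since `encoder`, `bottleneck` and `decoder` are not shown. The `?`/`None` in front is the batch axis that Keras adds automatically, so the model takes rank-4 batches, not rank-5 ones like [64,1,768,768,3].

```python
import tensorflow as tf

def tiny_model():
    # Same input declaration as unet(); the Conv2D body is a hypothetical
    # stand-in for encoder/bottleneck/decoder, which are not in the question.
    inputs = tf.keras.layers.Input(shape=(768, 768, 3))
    outputs = tf.keras.layers.Conv2D(2, 1, activation="sigmoid")(inputs)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

model = tiny_model()
# The batch axis (None) is prepended to the declared shape: rank-4 input
print(tuple(model.inputs[0].shape))  # (None, 768, 768, 3)
```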

Here is the data that is fed into the model:

train_dataset.element_spec
(TensorSpec(shape=(None, 768, 768, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 768, 768, 2), dtype=tf.float32, name=None))

val_dataset.element_spec
(TensorSpec(shape=(None, 768, 768, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 768, 768, 2), dtype=tf.float32, name=None))
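Note that `element_spec` only reports the static shapes tf.data recorded when the pipeline was built; the `IteratorGetNext` node compares the actual tensors against them at runtime, and that comparison is what fails here. One way to get exactly this mismatch (an assumption, since the pipeline code is not in the question) is a `tf.py_function` loader that returns each image with an extra leading axis of size 1, followed by `set_shape((768, 768, 3))` and `.batch(64)`: `set_shape` changes only the declared shape, not the data. The sketch uses small stand-in sizes and a hypothetical `load_image` helper; `tf.squeeze` removes the extra axis and `tf.ensure_shape` makes a wrong shape fail inside `map` rather than later in `model.fit`.

```python
import numpy as np
import tensorflow as tf

# Small stand-ins for the real sizes (768x768 images, batch of 64)
H, W, BATCH = 8, 8, 4

def load_image(i):
    # Hypothetical loader: returns one image WITH an extra leading axis of size 1
    return np.zeros((1, H, W, 3), dtype=np.float32)

def bad_map(i):
    img = tf.py_function(load_image, [i], tf.float32)
    img.set_shape((H, W, 3))  # only the static shape changes; runtime data stays (1, H, W, 3)
    return img

def good_map(i):
    img = tf.py_function(load_image, [i], tf.float32)
    img = tf.squeeze(img, axis=0)           # drop the extra axis at runtime
    return tf.ensure_shape(img, (H, W, 3))  # raises inside map if the shape is still wrong

bad_ds = tf.data.Dataset.range(16).map(bad_map).batch(BATCH)
good_ds = tf.data.Dataset.range(16).map(good_map).batch(BATCH)

print(bad_ds.element_spec)  # declared (None, 8, 8, 3), although the batches are rank 5
try:
    next(iter(bad_ds))
except tf.errors.InvalidArgumentError as e:
    print("bad pipeline:", e.message)
print(next(iter(good_ds)).shape)  # (4, 8, 8, 3)
```

If the masks (component 1) are loaded the same way, they presumably carry the same extra axis and need the same treatment before `.batch()`.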
