не работает tf.data.experimental.service.distribute с TPU? или я неправильно делаю?
Хочу отдельно запустить подготовку данных на локальной машине, а нейросеть на TPU из обычного Colab Pro. Запускаю TPU из локальной машины:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver().connect() # TPU detection
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.TPUStrategy(tpu)
Запускаю один или два рабочих процесса на локальной машине:
dispatcher = tf.data.experimental.service.DispatchServer()
dispatcher_address = dispatcher.target.split("://")[1]
# Start two workers
workers = [
tf.data.experimental.service.WorkerServer(
tf.data.experimental.service.WorkerConfig(
dispatcher_address=dispatcher_address)) for _ in range(2)
]
Загружаю данные из TFRecord в рабочих процессах и обрабатываю там же, а остальное делаю в TPU, вызываемый из локального DispatchServer:
opt = tf.data.Options()
opt.experimental_deterministic = False
dataset = tf.data.Dataset.from_tensor_slices(filenames).with_options(opt)
dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=AUTO)
dataset = dataset.map(parse_tfrecord, num_parallel_calls=AUTO)
...
return augmented.repeat().apply(tf.data.experimental.service.distribute(processing_mode="distributed_epoch", service=dispatcher.target)).batch(batch_size).prefetch(AUTO)
Тут выдаает предупреждение: /usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py:449: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options. warnings.warn("To make it possible to preserve tf.data options across "
И все потом зависает на:
with strategy.scope():
model = create_model()
model.summary()
Без DispatchServer и WorkerServer и tf.data.experimental.service.distribute все работает нормально. В create_model() создание или загрузка модели. Почему не работает? Вроде все по примерам из руководства TF2.8 делал.