В каком формате нужно передавать датасет в Trainer из библиотеки transformers?

У меня датасет с изображениями и аннотациями в формате COCO и при попытке дообучить модель trainer ругается на формат датасета. Я пытался предобработать датасет, но это проблему не решило. Вот мой код:

def id_by_name(filename, anns):
    for img in anns["images"]:
        if img["file_name"] == filename:
            return img["id"]
    return None


def mask_to_tensor(mask):
    # Округляем значения до ближайшего целого числа
    mask_rounded = np.round(mask).astype(np.uint8)
    
    # Преобразуем в тензор и нормализуем
    tensor = torch.tensor(mask_rounded).float() / 255.0
    
    return tensor


def prepare_data(examples, ann_path):
    with open(ann_path, 'r') as f:
        annotations = json.load(f) 
        
    labels = list()
    for img in examples['image']:
        label = {
                'class_labels': [],
                'masks': []
            }
        ann = annotations["annotations"][id_by_name(basename(img.filename), annotations)]
        masks = list()
        for mask in ann["segmentation"]:
            mask_tensor = mask_to_tensor(mask)
            masks.append(mask_tensor) 

        label['class_labels'].append(torch.tensor([ann['category_id']]))
        label['masks'].append(mask)
        labels.append(label)

    # Проверяем, является ли пример LazyBatch
    if isinstance(examples['image'][0], datasets.formatting.formatting.LazyBatch):
        # Загружаем изображения в память
        images = [
            PIL.Image.open(path) 
            for path in examples['image']
        ]
    else:
        images = examples['image']

    # Применяем трансформации
    transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize((224, 224)),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                       std=[0.229, 0.224, 0.225])
                ])

    valid_images = []
    for img in images:
        valid_images.append(transform(img))
    
    return {
        'pixel_values': torch.stack(valid_images),
        'labels': labels 
    }   

train = load_dataset(
    "train",
    split="train",
)

train = train.map(
    lambda x: prepare_data(x, "./train_annotations.json"),
    batched=True,
    num_proc=2,
    load_from_cache_file=False
).remove_columns("image")

val = load_dataset(
    "val",
    split="train"
)

val = val.map(
    lambda x: prepare_data(x, "./val_annotations.json"),
    batched=True,
    num_proc=2,
    load_from_cache_file=False
).remove_columns("image")

model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

training_args = TrainingArguments(
    output_dir="transformers_train_logs",
    evaluation_strategy="epoch",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=5,
    report_to="none",
    remove_unused_columns=False
)

metrics = evaluate.load("f1")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)

    return metric.compute(predictions=predictions, references=labels)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=val,
    compute_metrics=compute_metrics
)

trainer.train()

Ответы (0 шт):