Определение собственного датасета из синтетики в pytorch

Имеется набор изображений jpeg с синтезированными графиками, разбитых на 3 класса по 300 штук в каждом + 90 графиков для тестовой выборки.

Планируется создать нейросеть, определяющая класс графика по изображению на входе. Никак не могу разобраться, как создать в Pytorch собственный датасет с изображениями, т.к.все туториалы в интернете про то, как использовать готовые сеты из интернета.

Прошу привести пример реализации датасета из Jpeg или любую ссылку на материал по этой теме.

Спасибо.

Прикладываю текущую иерархию папок с изображениями + пример изображения каждого класса

Иерархия изображений

Класс Downs

Класс Peaks

Класс Rises


Ответы (1 шт):

Автор решения: MaxU

Воспользуйтесь torchvision.datasets.ImageFolder.

Пример:

transform_dict = {
        'src': transforms.Compose(
        [transforms.RandomResizedCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
         ]),
        'tar': transforms.Compose(
        [transforms.Resize(224),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
         ])
}
data = datasets.ImageFolder(root=root_path + dir, transform=transform_dict[phase])
data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, 
                                          drop_last=False, num_workers=4)
→ Ссылка