SOM - Ошибка при обучении данных (ошибка с размерностью)

Изучаю в данный момент машинное обучение и возникла необходимость сделать программу для самоорганизующихся карт Кохонена. В качестве датасета используются изображения, сжатые до 64х64, которые были конвертированы в пиксельные массивы с градацией серого.

IMG_SIZE = (64, 64) # Берем изображение 64х64 для наглядного обучения

digits=[] 
p = Path('dataset/simpsons_dataset/abraham_grampa_simpson').rglob("*.jpg") # Берем из директории все файлы с расширением .jpg
for item in p: # Проходимся по всем папкам (ищем файлы)
    if item.is_file(): # Проверка, является ли искомое изображение файлом (не директорией, ха)
        img = cv2.imread(str(item)) # Для начала по указанному пути директорий загружаем изображение
        img = cv2.resize(img, IMG_SIZE) # Сжимаем изображение до 64х64
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) # Преобразуем цветное изображение в массив с градацией серого
        digits.append(img) # Добавляем изображения в массив для дальнейшей обработки
digits=np.array(digits) # Преобразовываем наш массив для дальнейшей обработки матрицы изображений

Сама нейронная сеть, которая используется для алгоритма обучения сетей Кохонена.

class SOM(nn.Module):
    def __init__(self, input_size, out_size, lr=0.3, sigma=None):
        '''
        :param input_size:
        :param out_size:
        :param lr:
        :param sigma:
        '''
        super(SOM, self).__init__()
        self.input_size = input_size # Инициализируем параметр входного размера
        self.out_size = out_size # Размерность выходной матрицы

        self.lr = lr # скорость обучения
        if sigma is None: # 
            self.sigma = max(out_size) / 2
        else:
            self.sigma = float(sigma)

        print(f"Self Weight In Size: {input_size}")
        print(f"Self Weight Out Size: {out_size[0] * out_size[1]}")
        self.weight = nn.Parameter(torch.randn(input_size, out_size[0] * out_size[1]), requires_grad=False)
        print(f"Weights: {self.weight}")
        self.locations = nn.Parameter(torch.Tensor(list(self.get_map_index())), requires_grad=False)
        print(f"Locations: {self.locations}")
        self.pdist_fn = nn.PairwiseDistance(p=2)
        print(f"Pairwise Distance Function: {self.pdist_fn}")

    def get_map_index(self):
        '''Two-dimensional mapping function'''
        for x in range(self.out_size[0]):
            for y in range(self.out_size[1]):
                yield (x, y)

    def _neighborhood_fn(self, input, current_sigma): # Функция проверки расстояния между нейронами-победителями и соседами-нейронами
        '''e^(-(input / sigma^2))''' # Для получения наилучшего результата по нахождению нейрона-победителя используем Гауссову функцию
        print(f"Current Sigma: {current_sigma}")
        input.div_(current_sigma ** 2)
        input.neg_()
        input.exp_()
        
        print(f"Neighborhood Sigma Input Value: {input}")

        return input

    def forward(self, input):
        '''
        # Функция, ответственная за нахождение наименьшего расстояния от вектора внутреннего слоя, который находится на карте "проекции" до входного вектора.
        Find the location of best matching unit.
        :param input: data
        :return: location of best matching unit, loss
        '''
        batch_size = input.size()[0]
        print(f"Input Size: {input.size()}")
        print(f"Batch Size: {batch_size}")
        input = input.view(batch_size, -1, 1)
        batch_weight = self.weight.expand(batch_size, -1, -1)
        print(f"batch weight: {batch_weight}")

        dists = self.pdist_fn(input, batch_weight)
        print(f"Dists: {dists}")
        # Find best matching unit
        losses, bmu_indexes = dists.min(dim=1, keepdim=True) # Именно эта функция отвечает за нахождение наименьшей дистанции
        bmu_locations = self.locations[bmu_indexes] # <--- Здесь возникает ошибка при обучении нейросети

        return bmu_locations, losses.sum().div_(batch_size).item()

    def self_organizing(self, input, current_iter, max_iter):
        '''
        Train the Self Oranizing Map(SOM)
        :param input: training data
        :param current_iter: current epoch of total epoch
        :param max_iter: total epoch
        :return: loss (minimum distance)
        '''
        batch_size = input.size()[0]
        print(f"Batch Size for SOM: {batch_size}")
        #Set learning rate
        iter_correction = 1.0 - current_iter / max_iter
        print(f"SOM Iteration correction - Корректировка итераций: {iter_correction}")
        lr = self.lr * iter_correction
        print(f"SOM Learning Rate: {lr}")
        sigma = self.sigma * iter_correction
        print(f"SOM Sima: {sigma}")

        #Find best matching unit
        bmu_locations, loss = self.forward(input)
        print(f"SOM Best Matching Unit Locations: {bmu_locations}")
        print(f"SOM Losses: {loss}")

        print(f"SOM - Self.Locations: {self.locations.float()}")
        print(f"SOM - Best Matching Unit.Locations: {bmu_locations.float()}")
        distance_squares = self.locations.float() - bmu_locations.float()
        print(f"SOM Distance Squares: {distance_squares}")
        distance_squares.pow_(2)
        print(f"SOM Dist Squares Pow 2: {distance_squares.pow_(2)}")
        distance_squares = torch.sum(distance_squares, dim=2)
        print(f"SOM Dist Squares Sum with DIM=2: {distance_squares}")

        lr_locations = self._neighborhood_fn(distance_squares, sigma)
        print(f"SOM Leatning Rate - Neighborhood FN: {lr_locations}")
        lr_locations.mul_(lr).unsqueeze_(1)
        print(f"SOM unsqueezed (1): {lr_locations.mul_(lr).unsqueeze_(1)}")

        delta = lr_locations * (input.unsqueeze(2) - self.weight)
        print(f"SOM Input Value Unsqueezed (2): {input.unsqueeze(2)}")
        print(f"SOM Delta: {lr_locations * (input.unsqueeze(2) - self.weight)}")

        delta = delta.sum(dim=0)
        print(f"SOM Delta Sum with DIM=0: {delta}")
        delta.div_(batch_size)
        print(f"SOM delta div batch size: {delta.div_(batch_size)}")
        self.weight.data.add_(delta)
        print(f"self.weight.data ADD: {self.weight.data.add_(delta)}")

        return loss

Параметры для обучения:

data = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.Tensor(digits)),
    batch_size=32,
    shuffle=True) # Инициализация датасета

total_epoch=150
losses = []
m = 2
n = 5

som = SOM(input_size = (digits[0].shape[0] * digits[0].shape[1] * 1 * 1), out_size = (m, n , 1)) 

Как я обучаю нейросеть:

image = torch.from_numpy(digits[0]) # Берем первый элемент массива изображений
print("Level A")
som.train() 
print("Level B")

for epoch in tqdm(range(total_epoch)):
    loss = 0
    iter_ = 0
    #Train with each vector one by one
    for (image, ) in data:
        print("Level C") 
        print(f"Image Size(): {image.size()[0]}")
        print(f"Image View Size(): {image.view(image.size()[0], -1)}")
        loss += som.self_organizing(image.view(image.size()[0], -1), epoch, total_epoch) # При получения суммы потерь при обучении нейросети происходит ошибка.
        print("Level D")
        iter_ += 1
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print("Level E")
        with out:
            plt.figure()
            # buf = som.weight.view(128,128,-1).argmax(2).float() / 10
            buf = som.weight.view(64,64,-1).argmax(2).float() / 10
            print("Level F")
            out.clear_output(True)
            plt.imshow(buf, cmap='Accent')
            print("Level G")
            plt.title(f'{epoch+1}/{total_epoch} iters (Error: {loss/n})')
            plt.show()
            
        with out_clusters:
            out_clusters.clear_output(True)
            plt.figure()
            print("Level H")
            for i in range(m*n):
                print("Level I")
                plt.subplot(2, 5, i+1)
                plt.imshow(64, 64, -1)[..., i]
                # plt.imshow(som.weight.view(128, 128, -1)[..., i])
                print("Level J")
            plt.show()
    with out_history:
        print("Level K")
        out_history.clear_output(True)
        plt.figure(figsize=(10, 3))
        print("Level L")
        plt.plot(list(range(len(losses))), (losses), 'b-')
        plt.scatter(list(range(len(losses))), (losses))
        plt.show()
        print("Level M")
    losses.append(loss / n)
    print("Level N")

Сама ошибка:

--------------------------------------------------------------------------- IndexError Traceback (most recent call last) /var/folders/y5/pchlt5l12gx5_65v0_bn_fcm0000gn/T/ipykernel_12532/3074031020.py in 12 print(f"Image Size(): {image.size()[0]}") 13 print(f"Image View Size(): {image.view(image.size()[0], -1)}") ---> 14 loss += som.self_organizing(image.view(image.size()[0], -1), epoch, total_epoch) 15 print("Level D") 16 iter_ += 1

/var/folders/y5/pchlt5l12gx5_65v0_bn_fcm0000gn/T/ipykernel_12532/70923719.py in self_organizing(self, input, current_iter, max_iter) 83 84 #Find best matching unit ---> 85 bmu_locations, loss = self.forward(input) 86 print(f"SOM Best Matching Unit Locations: {bmu_locations}") 87 print(f"SOM Losses: {loss}")

/var/folders/y5/pchlt5l12gx5_65v0_bn_fcm0000gn/T/ipykernel_12532/70923719.py in forward(self, input) 60 # Find best matching unit 61 losses, bmu_indexes = dists.min(dim=1, keepdim=True) ---> 62 bmu_locations = self.locations[bmu_indexes] 63 64 return bmu_locations, losses.sum().div_(batch_size).item()

IndexError: index 2842 is out of bounds for dimension 0 with size 10

Скажите, пожалуйста: возможно ли, что вся ошибка кроется в задании размерности выходных параметров? Если да, то уточните, на какие их стоит поменять... Заранее благодарен!


Ответы (0 шт):