Почему numpy по-разному работает для магических методов

class Tensor:

    def __init__(self, value, requires_grad=False, local_gradients=None):
        self.value = value

    def __add__(self, other):
        print("adding")
        other = other if isinstance(other, Tensor) else Tensor(other)
        value = self.value + other.value
        return Tensor(value)

    def __radd__(self, other):
        print("radding")
        other = other if isinstance(other, Tensor) else Tensor(other)
        value = self.value + other.value
        return Tensor(value)

    @staticmethod
    def ones(shape):
        return Tensor(np.ones(shape))


a = np.random.randn(4, 3, 2, 1)
b = Tensor.ones((3, 2, 1))
b + a
a + b

Получаем такой вывод

adding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding
radding

Видим, что __add__ выполнилсь один раз. Как я и хотел А вот __radd__ выполнилось 24 раза (у нас как раз 24 элемента в массиве), то он как будто выполнил __radd__ поэлементно. Это не то что мы ожидаем. Есть варианты исправить такое поведение?


Ответы (2 шт):

Автор решения: CrazyElf
object1 + object2

Суть в том, что метод __radd__ у object2 будет вызван только в том случае, если не определён метод __add__ у object1. А он, видимо, определён в Numpy таким образом, что он просто пытается каждый элемент своего тензора object1 сложить с вашим объектом object2, для чего и вызывается __radd__ с каждым отдельном элементом тензора object1.

На английском SO пишут, что это можно победить, если унаследоваться от np.ndarray и переопределить метод __add__, но у меня что-то сходу не заработала такая схема, нужно разбираться.

→ Ссылка
Автор решения: Pak Uula

Весь фокус в том, что метод __radd__ вызывается только в том случае, если __add__ выбрасывает исключение NotImplemented. Но посмотрите, для примера, на вашу реализацию __add__: разве вы выбрасываете исключение? Нет, ваш метод сложения пытается привести правый аргумент к необходимому типу. Поэтому в выражениях, где слева ваш тензор, метод radd правого операнда никогда не будет вызван.

В случае с ndarray то же самое. Метод ndarray.__add__ пытается постоить из правого операнда массив, вызывая np.asarray(). Эта функция либо строит ndarray, либо возвращает объект без изменения. В случае с тензором происходит второй случай - он возвращается как есть. Функция ndarray.__add__ видит, что правый операнд не является массивом, и выполняет сложение массива со скаляром, то есть складывает каждый элемент массива с вашим тензором.

Так как для float нет операции сложения с Tensor, вызвается Tensor.__radd__.

Вот немного допиленная реализация сложения, которая выводит типы операндов:

import numpy as np

class Tensor:

    def __init__(self, value, requires_grad=False, local_gradients=None):
        self.value = value

    def __add__(self, other):
        print(f"adding: {type(self)} + {type(other)}")
        other = other if isinstance(other, Tensor) else Tensor(other)
        value = self.value + other.value
        return Tensor(value)

    def __radd__(self, other):
        print(f"radding: {type(other)} + {type(self)}")
        other = other if isinstance(other, Tensor) else Tensor(other)
        value = self.value + other.value
        return Tensor(value)

    @staticmethod
    def ones(shape):
        return Tensor(np.ones(shape))


a = np.random.randn(4, 3, 2, 1)
b = Tensor.ones((3, 2, 1))
b + a
a + b

Получаем:

adding: <class '__main__.Tensor'> + <class 'numpy.ndarray'>
radding: <class 'float'> + <class '__main__.Tensor'>
radding: <class 'float'> + <class '__main__.Tensor'>
.
.
.

Получается, numpy интепретирует ваш тензор как скаляр и выполняет операцию прибавления скаляра, то есть поэлементное сложение.

Сделать с этим вы, ИМХО, ничего не можете. Поэтому если вы хотите векторную операцию с массивом, вам нужно размещать тензор на первом месте.

Я бы вообще не стал перегружать оператор +, а сделал бы Tensor.add_array(a), чтобы случайно не напутать последовательность операндов и не вылавливать ошибку часами.

→ Ссылка