Expected all tensors to be on the same device, PyTorch Lightning

Пытаюсь реализовать свою модель трансформера для перевода, переписываю обучения с чистого pytorch на pytorch lightning (дальше pl). Если во время цикла обучения на pytorch все проходило нормально, то в случае pl появляется такая ошибка (далее привожу в пример ее кусок):

Cell In[70], line 21, in EncoderTransformerLayer.forward(self, value, key, query, mask)
     18 def forward(self, value, key, query, mask):
     19     # attn_output = self.dropout(self.norm(self.attention(value, key, query, mask))) # сюда pre ln
     20     # mlp_output = self.dropout(self.norm(self.mlp(attn_output))) # сюда pre ln
---> 21     value = self.norm_for_v(value)
     22     key = self.norm_for_k(key)
     23     query = self.norm_for_q(query)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/normalization.py:190, in LayerNorm.forward(self, input)
    189 def forward(self, input: Tensor) -> Tensor:
--> 190     return F.layer_norm(
    191         input, self.normalized_shape, self.weight, self.bias, self.eps)

File /opt/conda/lib/python3.10/site-packages/torch/nn/functional.py:2515, in layer_norm(input, normalized_shape, weight, bias, eps)
   2511 if has_torch_function_variadic(input, weight, bias):
   2512     return handle_torch_function(
   2513         layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
   2514     )
-> 2515 return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__native_layer_norm)

Если при использовании pytorch мы используем .cuda() или .to(device) для тензоров для вычисления на gpu, то в случае pl нужно все такие отправления на gpu из кода удалить - после их удаления как раз все и перестало работать.

Класс, где появляется данная ошибка, целиком:

#!g1.1
class EncoderTransformerLayer(pl.LightningModule):
    def __init__(self, hidden_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()

        self.attention = AttentionModule(hidden_dim, num_heads)
        self.mlp = MLP(hidden_dim)

        # self.norm = torch.nn.LayerNorm(hidden_dim)
        # self.dropout = torch.nn.Dropout(dropout)
        self.norm_for_v = torch.nn.LayerNorm(hidden_dim)
        self.norm_for_k = torch.nn.LayerNorm(hidden_dim)
        self.norm_for_q = torch.nn.LayerNorm(hidden_dim)

        self.norm_for_attention = torch.nn.LayerNorm(hidden_dim)
        self.norm_for_mlp = torch.nn.LayerNorm(hidden_dim)

    def forward(self, value, key, query, mask):
        # attn_output = self.dropout(self.norm(self.attention(value, key, query, mask))) # сюда pre ln
        # mlp_output = self.dropout(self.norm(self.mlp(attn_output))) # сюда pre ln
        value = self.norm_for_v(value)
        key = self.norm_for_k(key)
        query = self.norm_for_q(query)

        attn_output = self.attention(value, key, query, mask)
        attn_output = self.norm_for_attention(attn_output)

        mlp_output = self.mlp(attn_output)
        mlp_output = self.norm_for_mlp(mlp_output)

        return mlp_output

Из того, что пробовал для исправления - прописать .cpu() для ручной отправки вычислений на процессор, но если раньше это помогло, то в данном куске кода это уже не помогает.


Ответы (0 шт):