Expected all tensors to be on the same device, PyTorch Lightning
Пытаюсь реализовать свою модель трансформера для перевода, переписываю обучения с чистого pytorch на pytorch lightning (дальше pl). Если во время цикла обучения на pytorch все проходило нормально, то в случае pl появляется такая ошибка (далее привожу в пример ее кусок):
Cell In[70], line 21, in EncoderTransformerLayer.forward(self, value, key, query, mask)
18 def forward(self, value, key, query, mask):
19 # attn_output = self.dropout(self.norm(self.attention(value, key, query, mask))) # сюда pre ln
20 # mlp_output = self.dropout(self.norm(self.mlp(attn_output))) # сюда pre ln
---> 21 value = self.norm_for_v(value)
22 key = self.norm_for_k(key)
23 query = self.norm_for_q(query)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/normalization.py:190, in LayerNorm.forward(self, input)
189 def forward(self, input: Tensor) -> Tensor:
--> 190 return F.layer_norm(
191 input, self.normalized_shape, self.weight, self.bias, self.eps)
File /opt/conda/lib/python3.10/site-packages/torch/nn/functional.py:2515, in layer_norm(input, normalized_shape, weight, bias, eps)
2511 if has_torch_function_variadic(input, weight, bias):
2512 return handle_torch_function(
2513 layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
2514 )
-> 2515 return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__native_layer_norm)
Если при использовании pytorch мы используем .cuda() или .to(device) для тензоров для вычисления на gpu, то в случае pl нужно все такие отправления на gpu из кода удалить - после их удаления как раз все и перестало работать.
Класс, где появляется данная ошибка, целиком:
#!g1.1
class EncoderTransformerLayer(pl.LightningModule):
def __init__(self, hidden_dim: int, num_heads: int, dropout: float = 0.1):
super().__init__()
self.attention = AttentionModule(hidden_dim, num_heads)
self.mlp = MLP(hidden_dim)
# self.norm = torch.nn.LayerNorm(hidden_dim)
# self.dropout = torch.nn.Dropout(dropout)
self.norm_for_v = torch.nn.LayerNorm(hidden_dim)
self.norm_for_k = torch.nn.LayerNorm(hidden_dim)
self.norm_for_q = torch.nn.LayerNorm(hidden_dim)
self.norm_for_attention = torch.nn.LayerNorm(hidden_dim)
self.norm_for_mlp = torch.nn.LayerNorm(hidden_dim)
def forward(self, value, key, query, mask):
# attn_output = self.dropout(self.norm(self.attention(value, key, query, mask))) # сюда pre ln
# mlp_output = self.dropout(self.norm(self.mlp(attn_output))) # сюда pre ln
value = self.norm_for_v(value)
key = self.norm_for_k(key)
query = self.norm_for_q(query)
attn_output = self.attention(value, key, query, mask)
attn_output = self.norm_for_attention(attn_output)
mlp_output = self.mlp(attn_output)
mlp_output = self.norm_for_mlp(mlp_output)
return mlp_output
Из того, что пробовал для исправления - прописать .cpu() для ручной отправки вычислений на процессор, но если раньше это помогло, то в данном куске кода это уже не помогает.