
PyTorch Lightning DDP: constant tensors are not moved to the device / Tensor for argument #2 'mat1' is on CPU, but expected it to be on GPU

Problem

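I ran into this while porting code that had been running on a single GPU to DDP.

In __init__ I had already called .to(device) on s_i, but when I printed its device in the logs, it showed cpu.

That makes sense.

When __init__ runs, the module has not yet been moved to a GPU, so the device is still reported as cpu.

Inside forward the device shows up as cuda:0 through cuda:2, so I could have worked around the problem by feeding the tensor in as a forward input every time, but I wanted to fix the root cause.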

Traceback (most recent call last):
  File "/home/ubuntu/jini1114/DeepMC/trainer.py", line 81, in <module>
    trainer.fit(deepmc, datamodule=dl)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 460, in fit
    self._run(model)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 758, in _run
    self.dispatch()
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 799, in dispatch
    self.accelerator.start_training(self)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in run_stage
    return self.run_train()
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 844, in run_train
    self.run_sanity_check(self.lightning_module)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in run_sanity_check
    self.run_evaluation()
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 967, in run_evaluation
    output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 174, in evaluation_step
    output = self.trainer.accelerator.validation_step(args)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 226, in validation_step
    return self.training_type_plugin.validation_step(*args)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 340, in validation_step
    return self.model(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 57, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "/home/ubuntu/jini1114/DeepMC/net/deepmc.py", line 153, in validation_step
    y_hat = self([X,U])
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/jini1114/DeepMC/net/deepmc.py", line 114, in forward
    c_i = self.Position_based_content_attention(LSTM, self.s_i, i)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/jini1114/DeepMC/net/attention.py", line 50, in forward
    W_a_output = self.W_a(s_i.to(self.device))
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 94, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/functional.py", line 1753, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: Tensor for argument #2 'mat1' is on CPU, but expected it to be on GPU (while checking arguments for addmm)
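
The behavior behind this traceback can be reproduced in a small sketch. It uses a plain `nn.Module` rather than the original model; the class name `Attention`, the sizes, and the `init_device` attribute are hypothetical and only there for illustration. The point is that a tensor stored as an ordinary attribute keeps the device it had when `__init__` ran, even after the module is moved.

```python
import torch
from torch import nn


class Attention(nn.Module):  # hypothetical stand-in, not the original model
    def __init__(self, hidden=8):
        super().__init__()
        self.W_a = nn.Linear(hidden, hidden)
        # The device as seen at construction time: the module has not
        # been moved anywhere yet, so this is the CPU.
        self.init_device = self.W_a.weight.device
        # A plain tensor attribute: .to(...) here only pins it to the
        # device that was current during __init__.
        self.s_i = torch.rand(4, hidden).to(self.init_device)


model = Attention()
print(model.init_device)  # device observed during __init__

if torch.cuda.is_available():
    # Moving the module moves parameters, but not plain tensor attributes.
    model = model.to("cuda")
    print(model.W_a.weight.device, model.s_i.device)
    try:
        model.W_a(model.s_i)
    except RuntimeError as err:
        # Same class of device-mismatch error as in the traceback above.
        print("device mismatch:", err)
```

On a machine without a GPU the last branch is skipped, but the first print still shows that the device captured during `__init__` is the CPU.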

 

Solution

The answer was in the official PyTorch Lightning documentation all along; I just found it late. A reminder to read the docs properly.

The LightningModule knows what device it is on. You can access the reference via self.device. Sometimes it is necessary to store tensors as module attributes. However, if they are not parameters they will remain on the CPU even if the module gets moved to a new device. To prevent that and remain device agnostic, register the tensor as a buffer in your modules’s __init__ method with register_buffer().

In other words, register the tensor with register_buffer inside __init__.

It is simple to do.

#self.s_i = torch.rand((self.batch_size,self.num_decoder_hidden))
self.register_buffer("s_i", torch.rand((self.batch_size,self.num_decoder_hidden)))

#self.cell_state = torch.rand((self.batch_size,self.num_decoder_hidden))
self.register_buffer("cell_state", torch.rand((self.batch_size,self.num_decoder_hidden)))

#self.m_i = torch.rand((self.batch_size, 1))
self.register_buffer("m_i", torch.rand((self.batch_size, 1)))

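Once declared this way, the name passed as the string becomes an attribute, so the tensor is accessed as self.s_i and so on.

The buffer is then moved to the right device together with the module.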
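
To see the difference between a plain attribute and a registered buffer, here is a minimal sketch with a plain `nn.Module`. The class name `Decoder`, the sizes, and the `plain` attribute are hypothetical. So that it also runs on a machine without a GPU, it additionally calls `.double()`: like `.to(device)`, this module-wide conversion is applied to parameters and buffers but not to ordinary tensor attributes, so the dtype serves as a visible stand-in for the device.

```python
import torch
from torch import nn


class Decoder(nn.Module):  # hypothetical stand-in, not the original model
    def __init__(self, batch_size=4, hidden=8):
        super().__init__()
        self.W_a = nn.Linear(hidden, hidden)
        # Plain attribute: ignored by module-wide conversions.
        self.plain = torch.rand(batch_size, hidden)
        # Registered buffer: follows the module, accessed as self.s_i.
        self.register_buffer("s_i", torch.rand(batch_size, hidden))

    def forward(self):
        # Weight and buffer share device and dtype, so this works.
        return self.W_a(self.s_i)


model = Decoder()
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(target).double()

out = model()
print(model.s_i.device, model.s_i.dtype)      # follows the module
print(model.plain.device, model.plain.dtype)  # unchanged since __init__
```

One thing to keep in mind: buffers are saved in the module's `state_dict` by default. If a randomly initialized state like this should not end up in checkpoints, `register_buffer` accepts `persistent=False` in recent PyTorch versions; whether that is the right choice depends on the model.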

https://pytorch-lightning.readthedocs.io/en/1.4.0/advanced/multi_gpu.html

 
