본문 바로가기

Pytorch/Errors

[Pytorch] RuntimeError: Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count() is 1. Please use torch.load with map_location to map your storages to an existing device.

반응형

파이토치를 통해 코드를 짜다보면 모델을 불러오기 위해 load하는 경우가 상당히 많습니다. 그런데 종종 아래와 같은 에러가 발생합니다.

# 에러 메세지
RuntimeError: Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count()
is 1. Please use torch.load with map_location to map your storages to an existing device.
# 에러가 발생한 코드
checkpoint = torch.load(args.pretrained)

이러한 에러가 발생하는 이유는 load를 할때, 불러오는 모델을 학습한 환경에서 2개 이상의 gpu가 사용됐기 때문에, 한 개의 gpu에서 load를 하려고 할 때 발생하는 오류로 추측됩니다.

 

에러 해결방법은 에러 메세지와 같이 map_location을 추가해주면 됩니다.

checkpoint = torch.load(args.pretrained,map_location='cuda:0')

 

반응형