File size: 927 Bytes
e5e24c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def get_device():
    """Return the name of the first available accelerator backend.

    Prefers CUDA; otherwise probes for the Ascend NPU (``torch_npu``) and
    Cambricon MLU (``torch_mlu``) plugin packages, in that order.

    Returns:
        str: One of ``"cuda"``, ``"npu"`` or ``"mlu"``.

    Raises:
        NotImplementedError: If no supported accelerator is detected.
    """
    device = None
    if torch.cuda.is_available():
        device = "cuda"
    else:
        try:
            import torch_npu  # noqa: F401
            device = "npu"
        except ImportError:
            pass
        # Probe MLU only as a fallback: the previous unconditional probe
        # let a later backend overwrite an earlier NPU detection when both
        # plugin packages were importable.
        if device is None:
            try:
                import torch_mlu  # noqa: F401
                device = "mlu"
            except ImportError:
                pass
    if device is None:
        # Message now lists MLU too, matching the backends actually probed.
        raise NotImplementedError(
            "Supports only CUDA, NPU or MLU. If your device is CUDA, NPU "
            "or MLU, please make sure that your environmental settings "
            "are configured correctly."
        )
    return device
def get_torch_device_module():
    """Return the torch device module matching the detected backend.

    Maps the string from :func:`get_device` onto the torch submodule of
    the same name (``torch.cuda`` / ``torch.npu`` / ``torch.mlu``).

    Raises:
        NotImplementedError: If the detected device name is unrecognized.
    """
    backend = get_device()
    # The device string doubles as the torch attribute name; getattr keeps
    # the lookup lazy, so torch.npu/torch.mlu are touched only when chosen.
    if backend in ("cuda", "npu", "mlu"):
        return getattr(torch, backend)
    raise NotImplementedError
|