| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
def get_device():
    """Return the name of the accelerator backend usable by this process.

    CUDA is preferred: when ``torch.cuda.is_available()`` is true the result
    is ``"cuda"`` and no vendor plugin is imported. Otherwise the Ascend
    (``torch_npu``) and Cambricon (``torch_mlu``) plugins are probed by
    importing them; a plugin that imports successfully selects its backend.
    If both plugins import, ``"mlu"`` wins because it is probed last.

    Returns:
        str: One of ``"cuda"``, ``"npu"`` or ``"mlu"``.

    Raises:
        NotImplementedError: If CUDA is unavailable and neither plugin can
            be imported.
    """
    if torch.cuda.is_available():
        return "cuda"

    device = None
    # The imports are kept as real ``import`` statements (not a lookup via
    # ``importlib.util.find_spec``) so that the import itself still happens,
    # exactly as before, whenever a plugin is present.
    try:
        import torch_npu  # noqa: F401
    except ImportError:
        pass
    else:
        device = "npu"

    try:
        import torch_mlu  # noqa: F401
    except ImportError:
        pass
    else:
        device = "mlu"

    if device is None:
        # The message lists every backend probed above, MLU included.
        raise NotImplementedError(
            "Supports only CUDA, NPU or MLU. If your device is CUDA, NPU "
            "or MLU, please make sure that your environmental settings are "
            "configured correctly."
        )
    return device
def get_torch_device_module():
    """Return the ``torch`` submodule matching the detected backend.

    The backend name comes from ``get_device()``; ``"cuda"``, ``"npu"`` and
    ``"mlu"`` map to ``torch.cuda``, ``torch.npu`` and ``torch.mlu``
    respectively.

    Raises:
        NotImplementedError: If ``get_device()`` yields any other name.
    """
    backend = get_device()
    if backend not in ("cuda", "npu", "mlu"):
        raise NotImplementedError
    # Resolve the attribute lazily: ``torch.npu`` / ``torch.mlu`` are only
    # looked up when that backend was actually selected.
    return getattr(torch, backend)