Update timesfm_backend.py
Browse files- timesfm_backend.py +20 -9
timesfm_backend.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import time
|
| 2 |
import json
|
| 3 |
import logging
|
|
|
|
| 4 |
from typing import Any, Dict, List, Optional, Tuple, Sequence
|
| 5 |
|
| 6 |
import numpy as np
|
|
@@ -108,11 +109,26 @@ class TimesFMBackend(ChatBackend):
|
|
| 108 |
if self._model is not None:
|
| 109 |
return
|
| 110 |
try:
|
|
|
|
| 111 |
import timesfm # 2.5 API
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
cfg = timesfm.ForecastConfig(
|
| 117 |
max_context=1024,
|
| 118 |
max_horizon=256,
|
|
@@ -123,11 +139,6 @@ class TimesFMBackend(ChatBackend):
|
|
| 123 |
fix_quantile_crossing=True,
|
| 124 |
)
|
| 125 |
model.compile(cfg)
|
| 126 |
-
try:
|
| 127 |
-
# .to(device) may be no-op for this wrapper; safe to try
|
| 128 |
-
model.to(self.device) # type: ignore[attr-defined]
|
| 129 |
-
except Exception:
|
| 130 |
-
pass
|
| 131 |
self._model = model
|
| 132 |
logger.info("TimesFM 2.5 model loaded on %s", self.device)
|
| 133 |
except Exception as e:
|
|
|
|
| 1 |
import time
|
| 2 |
import json
|
| 3 |
import logging
|
| 4 |
+
import os
|
| 5 |
from typing import Any, Dict, List, Optional, Tuple, Sequence
|
| 6 |
|
| 7 |
import numpy as np
|
|
|
|
| 109 |
if self._model is not None:
|
| 110 |
return
|
| 111 |
try:
|
| 112 |
+
import os
|
| 113 |
import timesfm # 2.5 API
|
| 114 |
+
|
| 115 |
+
hf_token = getattr(settings, "HF_TOKEN", None) or os.environ.get("HF_TOKEN")
|
| 116 |
+
cache_dir = getattr(settings, "TIMESFM_CACHE_DIR", None)
|
| 117 |
+
|
| 118 |
+
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
|
| 119 |
+
self.model_id,
|
| 120 |
+
token=hf_token,
|
| 121 |
+
cache_dir=cache_dir,
|
| 122 |
+
local_files_only=False,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
try:
|
| 126 |
+
# .model holds the underlying nn.Module; fall back to instance if absent.
|
| 127 |
+
target = getattr(model, "model", model)
|
| 128 |
+
target.to(self.device) # type: ignore[arg-type]
|
| 129 |
+
except Exception:
|
| 130 |
+
pass
|
| 131 |
+
|
| 132 |
cfg = timesfm.ForecastConfig(
|
| 133 |
max_context=1024,
|
| 134 |
max_horizon=256,
|
|
|
|
| 139 |
fix_quantile_crossing=True,
|
| 140 |
)
|
| 141 |
model.compile(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
self._model = model
|
| 143 |
logger.info("TimesFM 2.5 model loaded on %s", self.device)
|
| 144 |
except Exception as e:
|