Mungert commited on
Commit
c7f8c69
·
verified ·
1 Parent(s): ba1196d

Update timesfm_backend.py

Browse files
Files changed (1) hide show
  1. timesfm_backend.py +20 -9
timesfm_backend.py CHANGED
@@ -1,6 +1,7 @@
1
  import time
2
  import json
3
  import logging
 
4
  from typing import Any, Dict, List, Optional, Tuple, Sequence
5
 
6
  import numpy as np
@@ -108,11 +109,26 @@ class TimesFMBackend(ChatBackend):
108
  if self._model is not None:
109
  return
110
  try:
 
111
  import timesfm # 2.5 API
112
- # instantiate 2.5 torch model
113
- model = timesfm.TimesFM_2p5_200M_torch()
114
- model.load_checkpoint() # uses built-in 2.5 checkpoint loader
115
- # compile with sane defaults; caller can override via payload
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  cfg = timesfm.ForecastConfig(
117
  max_context=1024,
118
  max_horizon=256,
@@ -123,11 +139,6 @@ class TimesFMBackend(ChatBackend):
123
  fix_quantile_crossing=True,
124
  )
125
  model.compile(cfg)
126
- try:
127
- # .to(device) may be no-op for this wrapper; safe to try
128
- model.to(self.device) # type: ignore[attr-defined]
129
- except Exception:
130
- pass
131
  self._model = model
132
  logger.info("TimesFM 2.5 model loaded on %s", self.device)
133
  except Exception as e:
 
1
  import time
2
  import json
3
  import logging
4
+ import os
5
  from typing import Any, Dict, List, Optional, Tuple, Sequence
6
 
7
  import numpy as np
 
109
  if self._model is not None:
110
  return
111
  try:
112
+ import os
113
  import timesfm # 2.5 API
114
+
115
+ hf_token = getattr(settings, "HF_TOKEN", None) or os.environ.get("HF_TOKEN")
116
+ cache_dir = getattr(settings, "TIMESFM_CACHE_DIR", None)
117
+
118
+ model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
119
+ self.model_id,
120
+ token=hf_token,
121
+ cache_dir=cache_dir,
122
+ local_files_only=False,
123
+ )
124
+
125
+ try:
126
+ # .model holds the underlying nn.Module; fall back to instance if absent.
127
+ target = getattr(model, "model", model)
128
+ target.to(self.device) # type: ignore[arg-type]
129
+ except Exception:
130
+ pass
131
+
132
  cfg = timesfm.ForecastConfig(
133
  max_context=1024,
134
  max_horizon=256,
 
139
  fix_quantile_crossing=True,
140
  )
141
  model.compile(cfg)
 
 
 
 
 
142
  self._model = model
143
  logger.info("TimesFM 2.5 model loaded on %s", self.device)
144
  except Exception as e: