Spaces:
Running
Running
dropping inference mode for now
Browse files
app.py
CHANGED
|
@@ -42,9 +42,11 @@ def load_model_from_hub(preset, device):
|
|
| 42 |
cache_dir="/tmp/",
|
| 43 |
)
|
| 44 |
|
|
|
|
|
|
|
| 45 |
model = ScoreFlow(scorenet, device=device, **model_params["PatchFlow"])
|
| 46 |
model.load_state_dict(load_file(hf_checkpoint), strict=True)
|
| 47 |
-
|
| 48 |
return model
|
| 49 |
|
| 50 |
|
|
|
|
| 42 |
cache_dir="/tmp/",
|
| 43 |
)
|
| 44 |
|
| 45 |
+
print("HF SAVE DIR:", hf_checkpoint)
|
| 46 |
+
|
| 47 |
model = ScoreFlow(scorenet, device=device, **model_params["PatchFlow"])
|
| 48 |
model.load_state_dict(load_file(hf_checkpoint), strict=True)
|
| 49 |
+
model = model.eval().requires_grad_(False)
|
| 50 |
return model
|
| 51 |
|
| 52 |
|
msma.py
CHANGED
|
@@ -81,7 +81,7 @@ class EDMScorer(torch.nn.Module):
|
|
| 81 |
|
| 82 |
self.register_buffer("sigma_steps", t_steps.to(torch.float64))
|
| 83 |
|
| 84 |
-
@torch.inference_mode()
|
| 85 |
def forward(
|
| 86 |
self,
|
| 87 |
x,
|
|
@@ -378,7 +378,7 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
| 378 |
|
| 379 |
with open(f"{experiment_dir}/config.json", "w") as f:
|
| 380 |
json.dump(model.config, f, sort_keys=True, indent=4)
|
| 381 |
-
|
| 382 |
# totaliters = int(epochs * train_len)
|
| 383 |
pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
|
| 384 |
step = 0
|
|
|
|
| 81 |
|
| 82 |
self.register_buffer("sigma_steps", t_steps.to(torch.float64))
|
| 83 |
|
| 84 |
+
# @torch.inference_mode()
|
| 85 |
def forward(
|
| 86 |
self,
|
| 87 |
x,
|
|
|
|
| 378 |
|
| 379 |
with open(f"{experiment_dir}/config.json", "w") as f:
|
| 380 |
json.dump(model.config, f, sort_keys=True, indent=4)
|
| 381 |
+
|
| 382 |
# totaliters = int(epochs * train_len)
|
| 383 |
pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
|
| 384 |
step = 0
|