Spaces:
Running
Running
batch as cmdline argument
Browse files
msma.py
CHANGED
|
@@ -331,12 +331,19 @@ def cache_score_norms(preset, dataset_path, outdir, batch_size):
|
|
| 331 |
default=4,
|
| 332 |
show_default=True,
|
| 333 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
@common_args
|
| 335 |
-
def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
| 336 |
print("using device:", DEVICE)
|
| 337 |
device = DEVICE
|
| 338 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
| 339 |
-
refimg, reflabel = dsobj[0]
|
| 340 |
print(f"Loaded {len(dsobj)} samples from {dataset_path}")
|
| 341 |
|
| 342 |
# Subset of training dataset
|
|
@@ -351,10 +358,10 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
| 351 |
val_ds = Subset(dsobj, range(train_len, train_len + val_len))
|
| 352 |
|
| 353 |
trainiter = torch.utils.data.DataLoader(
|
| 354 |
-
train_ds, batch_size=
|
| 355 |
)
|
| 356 |
testiter = torch.utils.data.DataLoader(
|
| 357 |
-
val_ds, batch_size=
|
| 358 |
)
|
| 359 |
|
| 360 |
scorenet = build_model_from_pickle(preset)
|
|
|
|
| 331 |
default=4,
|
| 332 |
show_default=True,
|
| 333 |
)
|
| 334 |
+
@click.option(
|
| 335 |
+
"--batch_size",
|
| 336 |
+
help="Number of samples per batch",
|
| 337 |
+
metavar="INT",
|
| 338 |
+
type=int,
|
| 339 |
+
default=128,
|
| 340 |
+
show_default=True,
|
| 341 |
+
)
|
| 342 |
@common_args
|
| 343 |
+
def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
|
| 344 |
print("using device:", DEVICE)
|
| 345 |
device = DEVICE
|
| 346 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
|
|
|
| 347 |
print(f"Loaded {len(dsobj)} samples from {dataset_path}")
|
| 348 |
|
| 349 |
# Subset of training dataset
|
|
|
|
| 358 |
val_ds = Subset(dsobj, range(train_len, train_len + val_len))
|
| 359 |
|
| 360 |
trainiter = torch.utils.data.DataLoader(
|
| 361 |
+
train_ds, batch_size=batch_size, num_workers=4, prefetch_factor=2, shuffle=True
|
| 362 |
)
|
| 363 |
testiter = torch.utils.data.DataLoader(
|
| 364 |
+
val_ds, batch_size=batch_size*2, num_workers=4, prefetch_factor=2
|
| 365 |
)
|
| 366 |
|
| 367 |
scorenet = build_model_from_pickle(preset)
|