porting in msma files + adding flow model utils

- app.py +1 -1
- dataset.py +269 -0
- flowutils.py +263 -0
- scorer.py → msma.py +203 -53
app.py CHANGED
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 
-from scorer import build_model, config_presets
+from msma import build_model, config_presets
 
 
 @cache
dataset.py ADDED
@@ -0,0 +1,269 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is licensed under a Creative Commons
+# Attribution-NonCommercial-ShareAlike 4.0 International License.
+# You should have received a copy of the license along with this
+# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+"""Streaming images and labels from datasets created with dataset_tool.py."""
+
+import json
+import os
+import zipfile
+
+import numpy as np
+import PIL.Image
+import torch
+
+import dnnlib
+
+try:
+    import pyspng
+except ImportError:
+    pyspng = None
+
+# ----------------------------------------------------------------------------
+# Abstract base class for datasets.
+
+
+class Dataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        name,  # Name of the dataset.
+        raw_shape,  # Shape of the raw image data (NCHW).
+        use_labels=True,  # Enable conditioning labels? False = label dimension is zero.
+        max_size=None,  # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
+        xflip=False,  # Artificially double the size of the dataset via x-flips. Applied after max_size.
+        random_seed=0,  # Random seed to use when applying max_size.
+        cache=False,  # Cache images in CPU memory?
+    ):
+        self._name = name
+        self._raw_shape = list(raw_shape)
+        self._use_labels = use_labels
+        self._cache = cache
+        self._cached_images = dict()  # {raw_idx: np.ndarray, ...}
+        self._raw_labels = None
+        self._label_shape = None
+
+        # Apply max_size.
+        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
+        if (max_size is not None) and (self._raw_idx.size > max_size):
+            np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
+            self._raw_idx = np.sort(self._raw_idx[:max_size])
+
+        # Apply xflip.
+        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
+        if xflip:
+            self._raw_idx = np.tile(self._raw_idx, 2)
+            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
+
+    def _get_raw_labels(self):
+        if self._raw_labels is None:
+            self._raw_labels = self._load_raw_labels() if self._use_labels else None
+            if self._raw_labels is None:
+                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
+            assert isinstance(self._raw_labels, np.ndarray)
+            assert self._raw_labels.shape[0] == self._raw_shape[0]
+            assert self._raw_labels.dtype in [np.float32, np.int64]
+            if self._raw_labels.dtype == np.int64:
+                assert self._raw_labels.ndim == 1
+                assert np.all(self._raw_labels >= 0)
+        return self._raw_labels
+
+    def close(self):  # to be overridden by subclass
+        pass
+
+    def _load_raw_image(self, raw_idx):  # to be overridden by subclass
+        raise NotImplementedError
+
+    def _load_raw_labels(self):  # to be overridden by subclass
+        raise NotImplementedError
+
+    def __getstate__(self):
+        return dict(self.__dict__, _raw_labels=None)
+
+    def __del__(self):
+        try:
+            self.close()
+        except:
+            pass
+
+    def __len__(self):
+        return self._raw_idx.size
+
+    def __getitem__(self, idx):
+        raw_idx = self._raw_idx[idx]
+        image = self._cached_images.get(raw_idx, None)
+        if image is None:
+            image = self._load_raw_image(raw_idx)
+            if self._cache:
+                self._cached_images[raw_idx] = image
+        assert isinstance(image, np.ndarray)
+        assert list(image.shape) == self._raw_shape[1:]
+        if self._xflip[idx]:
+            assert image.ndim == 3  # CHW
+            image = image[:, :, ::-1]
+        return image.copy(), self.get_label(idx)
+
+    def get_label(self, idx):
+        label = self._get_raw_labels()[self._raw_idx[idx]]
+        if label.dtype == np.int64:
+            onehot = np.zeros(self.label_shape, dtype=np.float32)
+            onehot[label] = 1
+            label = onehot
+        return label.copy()
+
+    def get_details(self, idx):
+        d = dnnlib.EasyDict()
+        d.raw_idx = int(self._raw_idx[idx])
+        d.xflip = int(self._xflip[idx]) != 0
+        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
+        return d
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def image_shape(self):  # [CHW]
+        return list(self._raw_shape[1:])
+
+    @property
+    def num_channels(self):
+        assert len(self.image_shape) == 3  # CHW
+        return self.image_shape[0]
+
+    @property
+    def resolution(self):
+        assert len(self.image_shape) == 3  # CHW
+        assert self.image_shape[1] == self.image_shape[2]
+        return self.image_shape[1]
+
+    @property
+    def label_shape(self):
+        if self._label_shape is None:
+            raw_labels = self._get_raw_labels()
+            if raw_labels.dtype == np.int64:
+                self._label_shape = [int(np.max(raw_labels)) + 1]
+            else:
+                self._label_shape = raw_labels.shape[1:]
+        return list(self._label_shape)
+
+    @property
+    def label_dim(self):
+        assert len(self.label_shape) == 1
+        return self.label_shape[0]
+
+    @property
+    def has_labels(self):
+        return any(x != 0 for x in self.label_shape)
+
+    @property
+    def has_onehot_labels(self):
+        return self._get_raw_labels().dtype == np.int64
+
+
+# ----------------------------------------------------------------------------
+# Dataset subclass that loads images recursively from the specified directory
+# or ZIP file.
+
+
+class ImageFolderDataset(Dataset):
+    def __init__(
+        self,
+        path,  # Path to directory or zip.
+        resolution=None,  # Ensure specific resolution, None = anything goes.
+        **super_kwargs,  # Additional arguments for the Dataset base class.
+    ):
+        self._path = path
+        self._zipfile = None
+
+        if os.path.isdir(self._path):
+            self._type = "dir"
+            self._all_fnames = {
+                os.path.relpath(os.path.join(root, fname), start=self._path)
+                for root, _dirs, files in os.walk(self._path)
+                for fname in files
+            }
+        elif self._file_ext(self._path) == ".zip":
+            self._type = "zip"
+            self._all_fnames = set(self._get_zipfile().namelist())
+        else:
+            raise IOError("Path must point to a directory or zip")
+
+        PIL.Image.init()
+        supported_ext = PIL.Image.EXTENSION.keys() | {".npy"}
+        self._image_fnames = sorted(
+            fname
+            for fname in self._all_fnames
+            if self._file_ext(fname) in supported_ext
+        )
+        if len(self._image_fnames) == 0:
+            raise IOError("No image files found in the specified path")
+
+        name = os.path.splitext(os.path.basename(self._path))[0]
+        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
+        if resolution is not None and (
+            raw_shape[2] != resolution or raw_shape[3] != resolution
+        ):
+            raise IOError("Image files do not match the specified resolution")
+        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
+
+    @staticmethod
+    def _file_ext(fname):
+        return os.path.splitext(fname)[1].lower()
+
+    def _get_zipfile(self):
+        assert self._type == "zip"
+        if self._zipfile is None:
+            self._zipfile = zipfile.ZipFile(self._path)
+        return self._zipfile
+
+    def _open_file(self, fname):
+        if self._type == "dir":
+            return open(os.path.join(self._path, fname), "rb")
+        if self._type == "zip":
+            return self._get_zipfile().open(fname, "r")
+        return None
+
+    def close(self):
+        try:
+            if self._zipfile is not None:
+                self._zipfile.close()
+        finally:
+            self._zipfile = None
+
+    def __getstate__(self):
+        return dict(super().__getstate__(), _zipfile=None)
+
+    def _load_raw_image(self, raw_idx):
+        fname = self._image_fnames[raw_idx]
+        ext = self._file_ext(fname)
+        with self._open_file(fname) as f:
+            if ext == ".npy":
+                image = np.load(f)
+                image = image.reshape(-1, *image.shape[-2:])
+            elif ext == ".png" and pyspng is not None:
+                image = pyspng.load(f.read())
+                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
+            else:
+                image = np.array(PIL.Image.open(f))
+                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
+        return image
+
+    def _load_raw_labels(self):
+        fname = "dataset.json"
+        if fname not in self._all_fnames:
+            return None
+        with self._open_file(fname) as f:
+            labels = json.load(f)["labels"]
+        if labels is None:
+            return None
+        labels = dict(labels)
+        labels = [labels[fname.replace("\\", "/")] for fname in self._image_fnames]
+        labels = np.array(labels)
+        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
+        return labels
+
+
+# ----------------------------------------------------------------------------
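Note: a minimal usage sketch of the new ImageFolderDataset; the dataset path here is an illustrative assumption, while resolution=64 matches how msma.py calls it below.

    from dataset import ImageFolderDataset

    # Path may be a directory or a .zip created with dataset_tool.py (path assumed).
    ds = ImageFolderDataset(path="datasets/img64", resolution=64)
    image, label = ds[0]            # image: uint8 ndarray in CHW layout
    print(len(ds), ds.image_shape)  # label is a float32 one-hot vector if dataset.json exists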
flowutils.py ADDED
@@ -0,0 +1,263 @@
+import pdb
+
+import normflows as nf
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+def build_flows(
+    latent_size, num_flows=4, num_blocks=2, hidden_units=128, context_size=64
+):
+    # Define flows
+
+    flows = []
+    for i in range(num_flows):
+        flows += [
+            nf.flows.CoupledRationalQuadraticSpline(
+                latent_size,
+                num_blocks=num_blocks,
+                num_hidden_channels=hidden_units,
+                num_context_channels=context_size,
+            )
+        ]
+        flows += [nf.flows.LULinearPermute(latent_size)]
+
+    # Set base distribution
+    q0 = nf.distributions.DiagGaussian(latent_size, trainable=True)
+
+    # Construct flow model
+    model = nf.ConditionalNormalizingFlow(q0, flows)
+
+    return model
+
+
+def get_emb(sin_inp):
+    """
+    Gets a base embedding for one dimension with sin and cos intertwined
+    """
+    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
+    return torch.flatten(emb, -2, -1)
+
+
+class PositionalEncoding2D(nn.Module):
+    def __init__(self, channels):
+        """
+        :param channels: The last dimension of the tensor you want to apply pos emb to.
+        """
+        super(PositionalEncoding2D, self).__init__()
+        self.org_channels = channels
+        channels = int(np.ceil(channels / 4) * 2)
+        self.channels = channels
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
+        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("cached_penc", None, persistent=False)
+
+    def forward(self, tensor):
+        """
+        :param tensor: A 4d tensor of size (batch_size, x, y, ch)
+        :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
+        """
+        if len(tensor.shape) != 4:
+            raise RuntimeError("The input tensor has to be 4d!")
+
+        if (
+            self.cached_penc is not None
+            and self.cached_penc.shape[:2] == tensor.shape[1:3]
+        ):
+            return self.cached_penc
+
+        self.cached_penc = None
+        batch_size, orig_ch, x, y = tensor.shape
+        pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
+        pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
+        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
+        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
+        emb_x = get_emb(sin_inp_x).unsqueeze(1)
+        emb_y = get_emb(sin_inp_y)
+        emb = torch.zeros(
+            (x, y, self.channels * 2),
+            device=tensor.device,
+            dtype=tensor.dtype,
+        )
+        emb[:, :, : self.channels] = emb_x
+        emb[:, :, self.channels : 2 * self.channels] = emb_y
+
+        self.cached_penc = emb
+
+        return self.cached_penc
+
+
+class SpatialNormer(nn.Module):
+    def __init__(
+        self,
+        in_channels,  # channels will be number of sigma scales in input
+        kernel_size=3,
+        stride=2,
+        padding=1,
+    ):
+        """
+        Note that the convolution will reduce the channel dimension
+        So (b, num_sigmas, c, h, w) -> (b, num_sigmas, new_h, new_w)
+        """
+        super().__init__()
+        self.conv = nn.Conv3d(
+            in_channels,
+            in_channels,
+            kernel_size,
+            # This is the real trick that ensures each
+            # sigma dimension is normed separately
+            groups=in_channels,
+            stride=(1, stride, stride),
+            padding=(0, padding, padding),
+            bias=False,
+        )
+        self.conv.weight.data.fill_(1)  # all ones weights
+        self.conv.weight.requires_grad = False  # freeze weights
+
+    @torch.no_grad()
+    def forward(self, x):
+        return self.conv(x.square()).pow_(0.5).squeeze(2)
+
+
+class PatchFlow(torch.nn.Module):
+    def __init__(
+        self,
+        input_size,
+        patch_size=3,
+        context_embedding_size=128,
+        num_blocks=2,
+        hidden_units=128,
+    ):
+        super().__init__()
+        num_sigmas, c, h, w = input_size
+        self.local_pooler = SpatialNormer(
+            in_channels=num_sigmas, kernel_size=patch_size
+        )
+        self.flow = build_flows(
+            latent_size=num_sigmas, context_size=context_embedding_size
+        )
+        self.position_encoding = PositionalEncoding2D(channels=context_embedding_size)
+
+        # caching pos encs
+        _, _, ctx_h, ctw_w = self.local_pooler(
+            torch.empty((1, num_sigmas, c, h, w))
+        ).shape
+        self.position_encoding(torch.empty(1, 1, ctx_h, ctw_w))
+        assert self.position_encoding.cached_penc.shape[-1] == context_embedding_size
+
+    def init_weights(self):
+        # Initialize weights with Xavier
+        linear_modules = list(
+            filter(lambda m: isinstance(m, nn.Linear), self.flow.modules())
+        )
+        total = len(linear_modules)
+
+        for idx, m in enumerate(linear_modules):
+            # Last layer gets init w/ zeros
+            if idx == total - 1:
+                nn.init.zeros_(m.weight.data)
+            else:
+                nn.init.xavier_uniform_(m.weight.data)
+
+            if m.bias is not None:
+                nn.init.zeros_(m.bias.data)
+
+    def forward(self, x, chunk_size=32):
+        b, s, c, h, w = x.shape
+        x_norm = self.local_pooler(x)
+        _, _, new_h, new_w = x_norm.shape
+        context = self.position_encoding(x_norm)
+
+        # (Patches * batch) x channels
+        local_ctx = rearrange(context, "h w c -> (h w) c")
+        patches = rearrange(x_norm, "b c h w -> (h w) b c")
+
+        nchunks = (patches.shape[0] + chunk_size - 1) // chunk_size
+        patches = patches.chunk(nchunks, dim=0)
+        ctx_chunks = local_ctx.chunk(nchunks, dim=0)
+        patch_logpx = []
+
+        # gc = repeat(global_ctx, "b c -> (n b) c", n=self.patch_batch_size)
+
+        for p, ctx in zip(patches, ctx_chunks):
+
+            # num patches in chunk (<= chunk_size)
+            n = p.shape[0]
+            ctx = repeat(ctx, "n c -> (n b) c", b=b)
+            p = rearrange(p, "n b c -> (n b) c")
+
+            # Compute log densities for each patch
+            logpx = self.flow.log_prob(p, context=ctx)
+            logpx = rearrange(logpx, "(n b) -> n b", n=n, b=b)
+            patch_logpx.append(logpx)
+            # del ctx, p
+
+        # print(p[:4], ctx[:4], logpx)
+        # Convert back to image
+        logpx = torch.cat(patch_logpx, dim=0)
+        logpx = rearrange(logpx, "(h w) b -> b 1 h w", b=b, h=new_h, w=new_w)
+
+        return logpx.contiguous()
+
+    @staticmethod
+    def stochastic_step(
+        scores, x_batch, flow_model, opt=None, train=False, n_patches=32, device="cpu"
+    ):
+        if train:
+            flow_model.train()
+            opt.zero_grad(set_to_none=True)
+        else:
+            flow_model.eval()
+
+        patches, context = PatchFlow.get_random_patches(
+            scores, x_batch, flow_model, n_patches
+        )
+
+        patch_feature = patches.to(device)
+        context_vector = context.to(device)
+        patch_feature = rearrange(patch_feature, "n b c -> (n b) c")
+        context_vector = rearrange(context_vector, "n b c -> (n b) c")
+
+        # global_pooled_image = flow_model.global_pooler(x_batch)
+        # global_context = flow_model.global_attention(global_pooled_image)
+        # gctx = repeat(global_context, "b c -> (n b) c", n=n_patches)
+
+        # # Concatenate global context to local context
+        # context_vector = torch.cat([context_vector, gctx], dim=1)
+
+        z, ldj = flow_model.flow.inverse_and_log_det(
+            patch_feature,
+            context=context_vector,
+        )
+
+        loss = -torch.mean(flow_model.flow.q0.log_prob(z) + ldj)
+        loss *= n_patches
+
+        if train:
+            loss.backward()
+            opt.step()
+
+        return loss.item() / n_patches
+
+    @staticmethod
+    def get_random_patches(scores, x_batch, flow_model, n_patches):
+        b = scores.shape[0]
+        h = flow_model.local_pooler(scores)
+        patches = rearrange(h, "b c h w -> (h w) b c")
+
+        context = flow_model.position_encoding(h)
+        context = rearrange(context, "h w c -> (h w) c")
+        context = repeat(context, "n c -> n b c", b=b)
+
+        # conserve gpu memory
+        patches = patches.cpu()
+        context = context.cpu()
+
+        # Get random patches
+        total_patches = patches.shape[0]
+        shuffled_idx = torch.randperm(total_patches)
+        rand_idx_batch = shuffled_idx[:n_patches]
+
+        return patches[rand_idx_batch], context[rand_idx_batch]
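Note: to make the tensor contract concrete, a minimal sketch of PatchFlow on random inputs. The batch size, sigma count, and 64x64 RGB shape are illustrative assumptions; the RGB channel count must be 3 so the depthwise SpatialNormer (cubic kernel, no depth padding) collapses the color axis to one.

    import torch
    from flowutils import PatchFlow

    num_sigmas, c, h, w = 10, 3, 64, 64             # assumed sizes
    flow = PatchFlow(input_size=(num_sigmas, c, h, w))
    flow.init_weights()

    scores = torch.randn(2, num_sigmas, c, h, w)    # stand-in for per-sigma score tensors
    heatmap = flow(scores)                          # per-patch log-densities
    print(heatmap.shape)                            # torch.Size([2, 1, 32, 32]) with the stride-2 pooling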
scorer.py → msma.py RENAMED
@@ -1,5 +1,6 @@
 import os
 import pickle
+from functools import partial
 from pickle import dump, load
 
 import numpy as np
@@ -9,9 +10,12 @@ from sklearn.mixture import GaussianMixture
 from sklearn.model_selection import GridSearchCV
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import StandardScaler
+from torch.utils.data import Subset
 from tqdm import tqdm
 
 import dnnlib
+from dataset import ImageFolderDataset
+from flowutils import PatchFlow
 
 model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
 
@@ -22,6 +26,17 @@ config_presets = {
 }
 
 
+class StandardRGBEncoder:
+    def __init__(self):
+        super().__init__()
+
+    def encode(self, x):  # raw pixels => final pixels
+        return x.to(torch.float32) / 127.5 - 1
+
+    def decode(self, x):  # final latents => raw pixels
+        return (x.to(torch.float32) * 127.5 + 128).clip(0, 255).to(torch.uint8)
+
+
 class EDMScorer(torch.nn.Module):
     def __init__(
         self,
@@ -41,6 +56,7 @@ class EDMScorer(torch.nn.Module):
         self.sigma_max = sigma_max
         self.sigma_data = sigma_data
         self.net = net.eval()
+        self.encoder = StandardRGBEncoder()
 
         # Adjust noise levels based on how far we want to accumulate
         self.sigma_min = 1e-1
@@ -63,7 +79,7 @@ class EDMScorer(torch.nn.Module):
         x,
         force_fp32=False,
     ):
-        x = x.to(torch.float32)
+        x = self.encoder.encode(x).to(torch.float32)
 
         batch_scores = []
         for sigma in self.sigma_steps:
@@ -76,6 +92,29 @@ class EDMScorer(torch.nn.Module):
         return batch_scores
 
 
+class ScoreFlow(torch.nn.Module):
+    def __init__(
+        self,
+        scorenet,
+        vectorize=False,
+        device="cpu",
+    ):
+        super().__init__()
+
+        h = w = scorenet.net.img_resolution
+        c = scorenet.net.img_channels
+        num_sigmas = len(scorenet.sigma_steps)
+        self.flow = PatchFlow((num_sigmas, c, h, w))
+
+        self.flow = self.flow.to(device)
+        self.scorenet = scorenet.to(device).requires_grad_(False)
+        self.flow.init_weights()
+
+    def forward(self, x, **score_kwargs):
+        x_scores = self.scorenet(x, **score_kwargs)
+        return self.flow(x_scores)
+
+
 def build_model(preset="edm2-img64-s-fid", device="cpu"):
     netpath = config_presets[preset]
     with dnnlib.util.open_url(netpath, verbose=1) as f:
@@ -85,41 +124,45 @@ def build_model(preset="edm2-img64-s-fid", device="cpu"):
     return model
 
 
-def
-
-    return np.quantile(gmm.score_samples(X), 0.1)
+def quantile_scorer(gmm, X, y=None):
+    return np.quantile(gmm.score_samples(X), 0.1)
 
-    X = torch.load(score_path)
 
-
-
-    clf.fit(X)
-    inlier_nll = -clf.score_samples(X)
-
-    param_grid = dict(
-        GMM__n_components=range(2, 11, 2),
-    )
+def train_gmm(score_path, outdir, grid_search=False):
+    X = torch.load(score_path)
 
-    grid = GridSearchCV(
-        estimator=clf,
-        param_grid=param_grid,
-        cv=10,
-        n_jobs=2,
-        verbose=1,
-        scoring=quantile_scorer,
+    gm = GaussianMixture(
+        n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
     )
 
-    grid_result = grid.fit(X)
+    if grid_search:
+        clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
+        param_grid = dict(
+            GMM__n_components=range(2, 11, 1),
+        )
+
+        grid = GridSearchCV(
+            estimator=clf,
+            param_grid=param_grid,
+            cv=5,
+            n_jobs=2,
+            verbose=1,
+            scoring=quantile_scorer,
+        )
+
+        grid_result = grid.fit(X)
+
+        print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
+        print("-----" * 15)
+        means = grid_result.cv_results_["mean_test_score"]
+        stds = grid_result.cv_results_["std_test_score"]
+        params = grid_result.cv_results_["params"]
+        for mean, stdev, param in zip(means, stds, params):
+            print("%f (%f) with: %r" % (mean, stdev, param))
+        clf = grid.best_estimator_
 
-    print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
-    print("-----" * 15)
-    means = grid_result.cv_results_["mean_test_score"]
-    stds = grid_result.cv_results_["std_test_score"]
-    params = grid_result.cv_results_["params"]
-    for mean, stdev, param in zip(means, stds, params):
-        print("%f (%f) with: %r" % (mean, stdev, param))
-
-    clf = grid.best_estimator_
+    clf.fit(X)
+    inlier_nll = -clf.score_samples(X)
 
     os.makedirs(outdir, exist_ok=True)
     with open(f"{outdir}/refscores.npz", "wb") as f:
@@ -134,26 +177,14 @@ def compute_gmm_likelihood(x_score, gmmdir):
         clf = load(f)
     nll = -clf.score_samples(x_score)
 
-    with np.load(f"{gmmdir}/refscores.npz", "
+    with np.load(f"{gmmdir}/refscores.npz", "rb") as f:
        ref_nll = f["arr_0"]
        percentile = (ref_nll < nll).mean()
 
    return nll, percentile
 
 
-def test_runner(device="cpu"):
-    # f = "doge.jpg"
-    f = "goldfish.JPEG"
-    image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
-    image = np.array(image)
-    image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
-    x = torch.from_numpy(image).unsqueeze(0).to(device)
-    model = build_model(device=device)
-    scores = model(x)
-    return scores
-
-
-def runner(preset, dataset_path, device="cpu"):
+def cache_score_norms(preset, dataset_path, device="cpu"):
     dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
     refimg, reflabel = dsobj[0]
     print(refimg.shape, refimg.dtype, reflabel)
@@ -178,19 +209,138 @@
     print(f"Computed score norms for {score_norms.shape[0]} samples")
 
 
+def train_flow(dataset_path, preset, device="cuda"):
+    dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
+    refimg, reflabel = dsobj[0]
+    print(f"Loaded {len(dsobj)} samples from {dataset_path}")
+
+    # Subset of training dataset
+    val_ratio = 0.1
+    train_len = int((1 - val_ratio) * len(dsobj))
+    val_len = len(dsobj) - train_len
+
+    print(
+        f"Generating train/test split with ratio={val_ratio} -> {train_len}/{val_len}..."
+    )
+    train_ds = Subset(dsobj, range(train_len))
+    val_ds = Subset(dsobj, range(train_len, train_len + val_len))
+
+    trainiter = torch.utils.data.DataLoader(
+        train_ds, batch_size=48, num_workers=4, prefetch_factor=2
+    )
+    testiter = torch.utils.data.DataLoader(
+        val_ds, batch_size=48, num_workers=4, prefetch_factor=2
+    )
+
+    model = ScoreFlow(build_model(preset=preset), device=device)
+    opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5)
+    train_step = partial(
+        PatchFlow.stochastic_step,
+        flow_model=model.flow,
+        opt=opt,
+        train=True,
+        n_patches=64,
+        device=device,
+    )
+    eval_step = partial(
+        PatchFlow.stochastic_step,
+        flow_model=model.flow,
+        train=False,
+        n_patches=128,
+        device=device,
+    )
+
+    pbar = tqdm(trainiter, desc="Train Loss: ? - Val Loss: ?")
+    step = 0
+
+    for x, _ in tqdm(trainiter):
+        x = x.to(device)
+        scores = model.scorenet(x)
+
+        if step == 0:
+            with torch.inference_mode():
+                val_loss = eval_step(scores, x)
+
+        train_loss = train_step(scores, x)
+
+        if (step + 1) % 10 == 0:
+
+            with torch.inference_mode():
+                val_loss = 0.0
+                for i, (x, _) in enumerate(testiter):
+                    x = x.to(device)
+                    scores = model.scorenet(x)
+                    val_loss += eval_step(scores, x)
+                    break
+                val_loss /= i + 1
+
+        pbar.set_description(
+            f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
+        )
+        step += 1
+
+    torch.save(model.flow.state_dict(), f"out/msma/{preset}/flow.pt")
+
+
+@torch.inference_mode
+def test_runner(device="cpu"):
+    # f = "doge.jpg"
+    f = "goldfish.JPEG"
+    image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
+    image = np.array(image)
+    image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
+    x = torch.from_numpy(image).unsqueeze(0).to(device)
+    model = build_model(device=device)
+    scores = model(x)
+
+    return scores
+
+
+def test_flow_runner(device="cpu", load_weights=None):
+    f = "doge.jpg"
+    # f = "goldfish.JPEG"
+    image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
+    image = np.array(image)
+    image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
+    x = torch.from_numpy(image).unsqueeze(0).to(device)
+    model = build_model(device=device)
+
+    score_flow = ScoreFlow(scorenet=model, device=device)
+
+    if load_weights is not None:
+        score_flow.flow.load_state_dict(torch.load(load_weights))
+
+    heatmap = score_flow(x)
+    print(heatmap.shape)
+
+    heatmap = score_flow(x).detach().cpu().numpy()
+    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min()) * 255
+    im = PIL.Image.fromarray(heatmap[0, 0])
+    im.convert("RGB").save(
+        "heatmap.png",
+    )
+
+    return
+
+
 if __name__ == "__main__":
     device = "cuda" if torch.cuda.is_available() else "cpu"
     preset = "edm2-img64-s-fid"
-    # runner(
+    imagenette_path = "/GROND_STOR/amahmood/datasets/img64/"
+
+    train_flow(imagenette_path, preset, device)
+    test_flow_runner("cuda", f"out/msma/{preset}/flow.pt")
+
+    # cache_score_norms(
     # preset=preset,
     # dataset_path="/GROND_STOR/amahmood/datasets/img64/",
     # device="cuda",
     # )
-    train_gmm(
-        f"out/msma/{preset}_imagenette_score_norms.pt", outdir=f"out/msma/{preset}"
-    )
-    s = test_runner(device=device)
-    s = s.square().sum(dim=(2, 3, 4)) ** 0.5
-    s = s.to("cpu").numpy()
-    nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}")
-    print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
+    # train_gmm(
+    #     f"out/msma/{preset}_imagenette_score_norms.pt", outdir=f"out/msma/{preset}"
+    # )
+    # s = test_runner(device=device)
+    # s = s.square().sum(dim=(2, 3, 4)) ** 0.5
+    # s = s.to("cpu").numpy()
+    # nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
+    # print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
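Note: putting the pieces together, a minimal sketch of the flow-based scoring path implied by train_flow and test_flow_runner above. The input file name is an illustrative assumption; the checkpoint path matches the one train_flow writes.

    import numpy as np
    import PIL.Image
    import torch

    from msma import ScoreFlow, build_model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    preset = "edm2-img64-s-fid"

    # Load a 64x64 RGB image as a uint8 NCHW tensor (file name assumed).
    img = PIL.Image.open("input.jpg").resize((64, 64), PIL.Image.Resampling.LANCZOS)
    x = torch.from_numpy(np.array(img).transpose(2, 0, 1)).unsqueeze(0).to(device)

    model = ScoreFlow(scorenet=build_model(preset=preset, device=device), device=device)
    model.flow.load_state_dict(torch.load(f"out/msma/{preset}/flow.pt"))

    with torch.inference_mode():
        heatmap = model(x)  # per-patch log-likelihoods; low values flag anomalous regions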