onullusoy committed on
Commit 5e94db5 · verified · 1 Parent(s): 5f57e45

Upload 12 files

app.py ADDED
@@ -0,0 +1,180 @@
+ import gradio as gr
+ import pandas as pd
+ from PIL import Image
+ from model_registry import (
+     ALL_CATEGORIES, DEFAULT_THRESHOLD, REGISTRY, get_model, NUDENET_ONLY
+ )
+ from video_utils import (
+     has_ffmpeg, probe_duration, extract_frames_ffmpeg, runs_from_indices,
+     merge_seconds_union, redact_with_ffmpeg
+ )
+
+ APP_TITLE = "Content Moderation Demo (Image & Video)"
+ APP_DESC = """
+ Minimal prototype: image/video analysis, model & category selection, and threshold control.
+ """
+
+ MODEL_NAMES = list(REGISTRY.keys())
+
+ # ---------- Shared ----------
+ def on_model_change(model_name):
+     if model_name in NUDENET_ONLY:
+         cats_state = gr.CheckboxGroup(choices=["sexual"], value=["sexual"], interactive=False, label="Categories")
+     else:
+         cats_state = gr.CheckboxGroup(choices=ALL_CATEGORIES, value=ALL_CATEGORIES, interactive=True, label="Categories")
+     th = DEFAULT_THRESHOLD
+     return cats_state, gr.Slider(minimum=0.0, maximum=1.0, value=th, step=0.01, label="Threshold")
+
+ # ---------- Image ----------
+ def analyze_image(model_name, image, selected_categories, threshold):
+     if image is None:
+         return "No image.", None, gr.update(visible=False)
+     pil = Image.fromarray(image) if not isinstance(image, Image.Image) else image
+     model = get_model(model_name)
+     allowed = set(getattr(model, "categories", ALL_CATEGORIES))
+     req = [c for c in selected_categories if c in allowed]
+     if not req:
+         return "No categories selected.", None, gr.update(visible=False)
+     scores = model.predict_image(pil, req)
+     verdict = "RISKY" if any(v >= threshold for v in scores.values()) else "SAFE"
+     df = pd.DataFrame([{"category": k, "score": f"{(float(v)*100):.1f}%"} for k, v in sorted(scores.items())])
+     if getattr(model, "supports_selected_tags", False):
+         extra = model.extra_selected_tags(pil, top_k=15)
+         txt = "\n".join(f"- {t}: {s:.3f}" for t, s in extra)
+         return verdict, df, gr.update(visible=True, value=txt)
+     else:
+         return verdict, df, gr.update(visible=False)
+
+ # ---------- Video ----------
+ def analyze_video(model_name, video_file, selected_categories, threshold, sampling_fps, redact):
+     import tempfile, os, shutil
+
+     if video_file is None:
+         return pd.DataFrame([{"segment": "Error: No video."}]), gr.update(value=None)
+
+     dur = probe_duration(video_file)
+     if dur is not None and dur > 60.0:
+         return pd.DataFrame([{"segment": "Error: Video too long (limit: 60s)."}]), gr.update(value=None)
+
+     model = get_model(model_name)
+     allowed = set(getattr(model, "categories", ALL_CATEGORIES))
+     req = [c for c in selected_categories if c in allowed]
+     if not req:
+         return pd.DataFrame([{"segment": "Error: No categories selected."}]), gr.update(value=None)
+
+     with tempfile.TemporaryDirectory() as td:
+         try:
+             frames = extract_frames_ffmpeg(video_file, sampling_fps, os.path.join(td, "frames"))
+         except Exception:
+             return pd.DataFrame([{"segment": "Error: FFmpeg not available or failed to extract frames."}]), gr.update(value=None)
+
+         all_hit_idx: list[int] = []
+         frame_stats: dict[int, dict] = {}
+
+         # Score every sampled frame; keep only frames with at least one category over threshold.
+         for fp, idx in frames:
+             with Image.open(fp) as im:
+                 pil = im.convert("RGB")
+             scores = model.predict_image(pil, req)
+             over = {c: float(scores.get(c, 0.0)) for c in req if float(scores.get(c, 0.0)) >= threshold}
+             if over:
+                 all_hit_idx.append(idx)
+                 peak_cat, peak_p = max(over.items(), key=lambda kv: kv[1])
+                 frame_stats[idx] = {"hits": over, "peak_cat": peak_cat, "peak_p": peak_p}
+
+         if not all_hit_idx:
+             return pd.DataFrame([{"segment": "(no hits)"}]), gr.update(value=None)
+
+         union_runs = runs_from_indices(sorted(set(all_hit_idx)))
+
+         # Per segment, count how often each category fired and track its peak probability.
+         rows = []
+         for seg_id, (a, b) in enumerate(union_runs, start=1):
+             cat_counts = {c: 0 for c in req}
+             cat_maxp = {c: 0.0 for c in req}
+             for i in range(a, b + 1):
+                 st = frame_stats.get(i)
+                 if not st:
+                     continue
+                 for c, p in st["hits"].items():
+                     cat_counts[c] += 1
+                     if p > cat_maxp[c]:
+                         cat_maxp[c] = p
+
+             present = [c for c in req if cat_counts[c] > 0]
+             present.sort(key=lambda c: (-cat_counts[c], -cat_maxp[c], c))
+
+             for c in present:
+                 rows.append({
+                     "seg": seg_id,
+                     "start": round(a / sampling_fps, 3),
+                     "end": round((b + 1) / sampling_fps, 3),
+                     "category": c,
+                     "max_p": round(cat_maxp[c], 3),
+                 })
+
+         df = pd.DataFrame(rows).sort_values(["seg", "max_p"], ascending=[True, False]).reset_index(drop=True)
+
+         out_video = gr.update(value=None)
+         if redact and has_ffmpeg():
+             intervals = merge_seconds_union(all_hit_idx, sampling_fps, pad=0.25)
+             try:
+                 out_path = os.path.join(td, "redacted.mp4")
+                 redact_with_ffmpeg(video_file, intervals, out_path)
+                 final_out = os.path.join(os.getcwd(), "redacted_output.mp4")
+                 shutil.copyfile(out_path, final_out)
+                 out_video = gr.update(value=final_out)
+             except Exception:
+                 out_video = gr.update(value=None)
+
+         return df, out_video
+
+
+ # ---------- UI ----------
+ with gr.Blocks(title=APP_TITLE, css=".wrap-row { gap: 16px; }") as demo:
+     gr.Markdown(f"# {APP_TITLE}")
+     gr.Markdown(APP_DESC)
+
+     with gr.Tabs():
+         with gr.Tab("Image"):
+             with gr.Row(elem_classes=["wrap-row"]):
+                 with gr.Column(scale=1, min_width=360):
+                     model_dd = gr.Dropdown(label="Model", choices=MODEL_NAMES, value=MODEL_NAMES[0])
+                     threshold = gr.Slider(0.0, 1.0, value=DEFAULT_THRESHOLD, step=0.01, label="Threshold")
+                     categories = gr.CheckboxGroup(label="Categories", choices=ALL_CATEGORIES, value=ALL_CATEGORIES)
+                     inp_img = gr.Image(type="pil", label="Upload Image")
+                     btn = gr.Button("Analyze", variant="primary")
+                 with gr.Column(scale=1, min_width=360):
+                     verdict = gr.Label(label="Verdict")
+                     scores_df = gr.Dataframe(headers=["category", "score"], datatype="str",
+                                              label="Scores", interactive=False)
+                     extra_tags = gr.Textbox(label="Selected tags", visible=False, lines=12)
+
+             model_dd.change(on_model_change, inputs=model_dd, outputs=[categories, threshold])
+             btn.click(analyze_image, inputs=[model_dd, inp_img, categories, threshold],
+                       outputs=[verdict, scores_df, extra_tags])
+
+         with gr.Tab("Video"):
+             with gr.Row(elem_classes=["wrap-row"]):
+                 with gr.Column(scale=1, min_width=360):
+                     v_model = gr.Dropdown(label="Model", choices=MODEL_NAMES, value=MODEL_NAMES[0])
+                     v_threshold = gr.Slider(0.0, 1.0, value=DEFAULT_THRESHOLD,
+                                             step=0.01, label="Threshold")
+                     v_fps = gr.Slider(0.25, 5.0, value=1.0, step=0.25, label="Sampling FPS")
+                     v_redact = gr.Checkbox(label="Redact scenes (requires FFmpeg)", value=False)
+                     v_categories = gr.CheckboxGroup(label="Categories", choices=ALL_CATEGORIES, value=ALL_CATEGORIES)
+                     v_input = gr.Video(label="Upload short video (≤ 60s)")
+                     v_btn = gr.Button("Analyze Video", variant="primary")
+                 with gr.Column(scale=1, min_width=360):
+                     v_segments = gr.Dataframe(label="Segments", interactive=False)
+                     v_out = gr.Video(label="Redacted Video")
+
+             v_model.change(on_model_change, inputs=v_model, outputs=[v_categories, v_threshold])
+             v_btn.click(analyze_video, inputs=[v_model, v_input, v_categories, v_threshold, v_fps, v_redact],
+                         outputs=[v_segments, v_out])
+
+ if __name__ == "__main__":
+     demo.launch()
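
For reference, the per-image decision in `analyze_image` (and per-frame decision in `analyze_video`) reduces to an any-category-over-threshold rule. A toy illustration with made-up scores, not produced by any of the bundled models:

```python
# Illustrative only: flag as RISKY as soon as any selected category reaches the threshold.
scores = {"alcohol": 0.12, "violence": 0.61, "smoking": 0.05}   # hypothetical model output
threshold = 0.5
verdict = "RISKY" if any(v >= threshold for v in scores.values()) else "SAFE"
print(verdict)  # RISKY, because violence >= 0.5
```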
model_registry.py ADDED
@@ -0,0 +1,102 @@
+ from typing import Dict, List, Tuple
+ from PIL import Image
+
+ ALL_CATEGORIES = [
+     "alcohol", "drugs", "weapons", "gambling",
+     "nudity", "sexy", "smoking", "violence"
+ ]
+ DEFAULT_THRESHOLD = 0.5
+ NUDENET_ONLY = {"clip-nudenet-lp", "siglip-nudenet-lp"}
+
+ class BaseModel:
+     name = "base"
+     supports_selected_tags = False
+     categories = ALL_CATEGORIES
+     def load(self): raise NotImplementedError
+     def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
+         raise NotImplementedError
+     def extra_selected_tags(self, pil_image: Image.Image, top_k: int = 10) -> List[Tuple[str, float]]:
+         return []
+
+ class Clip_MultiLabel(BaseModel):
+     name = "clip-multilabel"
+     categories = ALL_CATEGORIES
+     def __init__(self, head_path="weights/clip_multilabel.pt"):
+         self._cfg = dict(head_path=head_path, categories=self.categories)
+         self._m = None
+     def load(self):
+         from src.models import CLIPMultiLabel
+         if self._m is None:
+             self._m = CLIPMultiLabel(**self._cfg)
+     def predict_image(self, pil_image, requested_categories: List[str]) -> Dict[str, float]:
+         p = self._m.prob([pil_image])[0].tolist()
+         return {c: float(p[i]) for i, c in enumerate(self.categories) if c in requested_categories}
+
+ class _EVABaseAdapter(BaseModel):
+     supports_selected_tags = True
+     REPO_ID = ""
+     TAG_CSV = ""
+     def __init__(self, head_path: str):
+         self._cfg = dict(head_path=head_path, categories=self.categories)
+         self._m = None
+     def load(self):
+         from src.models.eva_headpreserving import EVAHeadPreserving
+         if self._m is None:
+             self._m = EVAHeadPreserving(repo_id=self.REPO_ID,
+                                         head_path=self._cfg["head_path"],
+                                         categories=self.categories,
+                                         tag_csv=self.TAG_CSV)
+     def predict_image(self, pil_image, requested_categories: List[str]) -> Dict[str, float]:
+         p = self._m.prob([pil_image])[0].tolist()
+         return {c: float(p[i]) for i, c in enumerate(self.categories) if c in requested_categories}
+     def extra_selected_tags(self, pil_image: Image.Image, top_k: int = 50) -> List[Tuple[str, float]]:
+         return self._m.top_tags(pil_image, top_k=top_k)
+
+ class WDEva02_Multitask(_EVABaseAdapter):
+     name = "wdeva02-multitask"
+     REPO_ID = "SmilingWolf/wd-eva02-large-tagger-v3"
+     TAG_CSV = "wdeva02_tags.csv"
+     def __init__(self, head_path="weights/wdeva02.pt"): super().__init__(head_path=head_path)
+
+ class Animetimm_Multitask(_EVABaseAdapter):
+     name = "animetimm-multitask"
+     REPO_ID = "animetimm/eva02_large_patch14_448.dbv4-full"
+     TAG_CSV = "animetimm_tags.csv"
+     def __init__(self, head_path="weights/animetimm.pt"): super().__init__(head_path=head_path)
+
+ class Clip_NudeNet_LP(BaseModel):
+     name = "clip-nudenet-lp"
+     categories = ["sexual"]
+     def __init__(self, head_path: str = "weights/clip_nudenet_lp.npz"):
+         self._cfg = dict(head_path=head_path)
+         self._lp = None
+     def load(self):
+         if self._lp is None:
+             from src.models import CLIPLinearProbe
+             self._lp = CLIPLinearProbe(**self._cfg)
+     def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
+         return {"sexual": float(self._lp.prob([pil_image])[0])}
+
+ class Siglip_NudeNet_LP(BaseModel):
+     name = "siglip-nudenet-lp"
+     categories = ["sexual"]
+     def __init__(self, head_path: str = "weights/siglip_nudenet_lp.npz"):
+         self._cfg = dict(head_path=head_path)
+         self._lp = None
+     def load(self):
+         if self._lp is None:
+             from src.models import SigLIPLinearProbe
+             self._lp = SigLIPLinearProbe(**self._cfg)
+     def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
+         return {"sexual": float(self._lp.prob([pil_image])[0])}
+
+ # Registry keys are the names shown in the UI dropdowns.
+ REGISTRY = {
+     "clip-multilabel": Clip_MultiLabel(),
+     "wdeva02-multilabel": WDEva02_Multitask(),
+     "animetimm-multilabel": Animetimm_Multitask(),
+     "clip-nudenet-lp": Clip_NudeNet_LP(),
+     "siglip-nudenet-lp": Siglip_NudeNet_LP(),
+ }
+
+ def get_model(name: str) -> BaseModel:
+     # Lazy-load each model the first time it is requested, then reuse the instance.
+     m = REGISTRY[name]
+     if not hasattr(m, "_loaded"):
+         m.load()
+         m._loaded = True
+     return m
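
The registry contract is small: an entry only needs `load()` and `predict_image()`, and `get_model()` lazy-loads it once. A hypothetical constant-score model (not part of this commit) would plug in like this:

```python
from typing import Dict, List
from PIL import Image

# Hypothetical stand-in model, for illustration only.
class DummyModel(BaseModel):
    name = "dummy"
    categories = ALL_CATEGORIES

    def load(self):
        pass  # nothing to load

    def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
        # Returns 0.0 for every requested category, i.e. always "safe".
        return {c: 0.0 for c in requested_categories}

REGISTRY["dummy"] = DummyModel()
m = get_model("dummy")  # load() runs once; the _loaded flag caches the instance
print(m.predict_image(Image.new("RGB", (64, 64)), ["violence"]))  # {'violence': 0.0}
```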
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ open_clip_torch
+ transformers
+ huggingface_hub
+ Pillow
+ numpy
+ pandas
+ gradio>=4.44.0
+ opencv-python-headless
+ ffmpeg-python
+ timm
src/models/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .clip_lp import CLIPLinearProbe
+ from .siglip_lp import SigLIPLinearProbe
+ from .clip_multilabel import CLIPMultiLabel
+ from .eva_headpreserving import EVAHeadPreserving
src/models/animetimm_tags.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/models/clip_lp.py ADDED
@@ -0,0 +1,44 @@
+ from __future__ import annotations
+ import numpy as np
+ import torch
+ import open_clip
+ from contextlib import nullcontext
+
+ from src.models.utils import l2norm_rows
+
+ class CLIPLinearProbe:
+     def __init__(self, head_path):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
+         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+             "ViT-L-14", pretrained="openai", device=self.device
+         )
+         self.model.eval().requires_grad_(False)
+         npz = np.load(head_path)
+         self.w = torch.from_numpy(npz["w"]).to(self.device).float()
+         self.b = torch.from_numpy(npz["b"]).to(self.device).float()
+
+         # use_amp must be defined on CPU as well, otherwise encode() raises AttributeError.
+         self.use_amp = (self.device == "cuda")
+         if self.device == "cuda":
+             torch.backends.cuda.matmul.allow_tf32 = True
+             torch.backends.cudnn.benchmark = True
+
+     @torch.inference_mode()
+     def encode(self, pil_list) -> torch.Tensor:
+         x = torch.stack([self.preprocess(im.convert("RGB")) for im in pil_list], 0)
+         x = x.to(self.device, non_blocking=True, memory_format=torch.channels_last)
+         ctx = torch.amp.autocast("cuda", dtype=self.torch_dtype) if self.use_amp else nullcontext()
+         with ctx:
+             f = self.model.encode_image(x)
+         f = f.float()
+         return l2norm_rows(f)
+
+     @torch.inference_mode()
+     def logits(self, pil_list) -> torch.Tensor:
+         f = self.encode(pil_list)
+         return (f @ self.w + self.b).squeeze(1)
+
+     @torch.inference_mode()
+     def prob(self, pil_list) -> torch.Tensor:
+         z = torch.clamp(self.logits(pil_list), -50, 50)
+         return torch.sigmoid(z)
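
`CLIPLinearProbe` expects the `.npz` head to hold two arrays, `w` and `b`; judging from `(f @ self.w + self.b).squeeze(1)`, `w` is presumably `(feature_dim, 1)` and `b` is `(1,)`. A minimal sketch of writing such a file, assuming the 768-dimensional ViT-L/14 image embedding (shapes inferred, not confirmed by the commit):

```python
import numpy as np

# Assumed layout of the linear-probe head file, inferred from how it is loaded above.
D = 768  # ViT-L/14 image embedding size
w = np.zeros((D, 1), dtype=np.float32)  # probe weights for the single "sexual" logit
b = np.zeros((1,), dtype=np.float32)    # probe bias
np.savez("weights/clip_nudenet_lp.npz", w=w, b=b)
```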
src/models/clip_multilabel.py ADDED
@@ -0,0 +1,49 @@
+ from __future__ import annotations
+ import torch
+ import open_clip
+ from contextlib import nullcontext
+
+ from src.models.utils import l2norm_rows
+
+ class CLIPMultiLabel:
+     def __init__(self, head_path, categories):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
+         self.categories = list(categories)
+         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+             "ViT-L-14", pretrained="openai", device=self.device
+         )
+         self.model.eval().requires_grad_(False)
+
+         ckpt = torch.load(head_path, map_location=self.device, weights_only=True)
+         state = ckpt.get("model_state", ckpt)
+
+         # head.weight is stored as (num_categories, D); transpose so logits = features @ w + b.
+         w = state["head.weight"].to(self.device).float()
+         b = state["head.bias"].to(self.device).float()
+         w = w.t()
+
+         self.w, self.b = w, b
+
+         # use_amp must be defined on CPU as well, otherwise encode() raises AttributeError.
+         self.use_amp = (self.device == "cuda")
+         if self.device == "cuda":
+             torch.backends.cuda.matmul.allow_tf32 = True
+             torch.backends.cudnn.benchmark = True
+
+     @torch.inference_mode()
+     def encode(self, pil_list) -> torch.Tensor:
+         x = torch.stack([self.preprocess(im.convert("RGB")) for im in pil_list], 0)
+         x = x.to(self.device, non_blocking=True, memory_format=torch.channels_last)
+         ctx = torch.amp.autocast("cuda", dtype=self.torch_dtype) if self.use_amp else nullcontext()
+         with ctx:
+             f = self.model.encode_image(x)
+         return l2norm_rows(f.float())
+
+     @torch.inference_mode()
+     def logits(self, pil_list) -> torch.Tensor:
+         f = self.encode(pil_list)
+         return f @ self.w + self.b
+
+     @torch.inference_mode()
+     def prob(self, pil_list) -> torch.Tensor:
+         z = torch.clamp(self.logits(pil_list), -50, 50)
+         return torch.sigmoid(z)
src/models/eva_headpreserving.py ADDED
@@ -0,0 +1,99 @@
+ from __future__ import annotations
+ from typing import List, Tuple
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import timm
+ from timm.data import resolve_model_data_config, create_transform
+ from contextlib import nullcontext
+
+ from .utils import load_tag_names
+
+ class EVAHeadPreserving:
+     """
+     Head-preserving inference for EVA-02 backbones (Animetimm / WD-EVA02).
+     Interface: encode / logits / prob / tags_prob / top_tags
+     """
+     def __init__(self,
+                  repo_id: str,
+                  head_path: str,
+                  categories: List[str],
+                  tag_csv: str = "selected_tags.csv"):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
+         self.use_amp = (self.device == "cuda")
+
+         self.categories = list(categories)
+         self.tag_csv = tag_csv
+
+         self.backbone = timm.create_model(f"hf-hub:{repo_id}", pretrained=True)
+         self.backbone = self.backbone.to(self.device).eval().requires_grad_(False)
+
+         cfg = resolve_model_data_config(self.backbone)
+         self.preprocess = create_transform(**cfg)
+
+         # Probe the backbone once to get the pre-logits feature size (D) and tag-head size (T).
+         with torch.no_grad():
+             in_size = cfg.get("input_size", (3, 448, 448))
+             h, w = int(in_size[-2]), int(in_size[-1])
+             dummy = torch.zeros(1, 3, h, w, device=self.device)
+             fx = self.backbone.forward_features(dummy)
+             pre = self.backbone.forward_head(fx, pre_logits=True)
+             tags_log = self.backbone.forward_head(fx, pre_logits=False)
+             D, T = int(pre.shape[-1]), int(tags_log.shape[-1])
+
+         self.custom_head = nn.Linear(D, len(self.categories)).to(self.device).eval().requires_grad_(False)
+
+         ckpt = torch.load(head_path, map_location=self.device, weights_only=True)
+         state = ckpt.get("state_dict", ckpt)
+
+         w = state["head.weight"].to(self.device).float()
+         b = state["head.bias"].to(self.device).float()
+         if w.shape != self.custom_head.weight.shape and w.t().shape == self.custom_head.weight.shape:
+             w = w.t()
+
+         with torch.no_grad():
+             self.custom_head.weight.copy_(w)
+             self.custom_head.bias.copy_(b)
+
+         self.tag_names = load_tag_names(T, self.tag_csv)
+
+         if self.device == "cuda":
+             torch.backends.cuda.matmul.allow_tf32 = True
+             torch.backends.cudnn.allow_tf32 = True
+             torch.backends.cudnn.benchmark = True
+
+     @torch.inference_mode()
+     def encode(self, pil_list: List) -> Tuple[torch.Tensor, torch.Tensor]:
+         x = torch.stack([self.preprocess(im.convert("RGB")) for im in pil_list], 0)
+         x = x.to(self.device, non_blocking=True, memory_format=torch.channels_last)
+         ctx = torch.amp.autocast("cuda", dtype=self.torch_dtype) if self.use_amp else nullcontext()
+         with ctx:
+             fx = self.backbone.forward_features(x)
+             pre = self.backbone.forward_head(fx, pre_logits=True)
+             feat = F.normalize(pre, dim=1)
+             tags_log = self.backbone.forward_head(fx, pre_logits=False)
+         return feat.float(), tags_log.float()
+
+     @torch.inference_mode()
+     def logits(self, pil_list: List) -> torch.Tensor:
+         feat_norm, _ = self.encode(pil_list)
+         return self.custom_head(feat_norm)
+
+     @torch.inference_mode()
+     def prob(self, pil_list: List) -> torch.Tensor:
+         z = torch.clamp(self.logits(pil_list), -20, 20)
+         return torch.sigmoid(z)
+
+     @torch.inference_mode()
+     def tags_prob(self, pil_list: List) -> torch.Tensor:
+         _, tags_log = self.encode(pil_list)
+         z = torch.clamp(tags_log, -20, 20)
+         return torch.sigmoid(z)
+
+     @torch.inference_mode()
+     def top_tags(self, pil_image, top_k: int = 50):
+         p = self.tags_prob([pil_image])[0].tolist()
+         k = max(0, min(top_k, len(p)))
+         idx = sorted(range(len(p)), key=lambda i: -p[i])[:k]
+         names = self.tag_names
+         return [(names[i] if i < len(names) else f"tag_{i:04d}", float(p[i])) for i in idx]
src/models/siglip_lp.py ADDED
@@ -0,0 +1,46 @@
+ from __future__ import annotations
+ import numpy as np
+ import torch
+ from contextlib import nullcontext
+ from transformers import AutoProcessor, SiglipModel
+
+ from src.models.utils import l2norm_rows
+
+ class SigLIPLinearProbe:
+     def __init__(self, head_path):
+         self.model_id = "google/siglip-so400m-patch14-384"
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
+         self.model = SiglipModel.from_pretrained(self.model_id, dtype=self.torch_dtype).to(self.device)
+         self.model.eval().requires_grad_(False)
+         self.processor = AutoProcessor.from_pretrained(self.model_id)
+
+         npz = np.load(head_path)
+         self.w = torch.from_numpy(npz["w"]).to(self.device).float()
+         self.b = torch.from_numpy(npz["b"]).to(self.device).float()
+
+         # use_amp must be defined on CPU as well, otherwise encode() raises AttributeError.
+         self.use_amp = (self.device == "cuda")
+         if self.device == "cuda":
+             torch.backends.cuda.matmul.allow_tf32 = True
+             torch.backends.cudnn.benchmark = True
+
+     @torch.inference_mode()
+     def encode(self, pil_list) -> torch.Tensor:
+         imgs = [im.convert("RGB") for im in pil_list]
+         enc = self.processor(images=imgs, return_tensors="pt")
+         x = enc["pixel_values"].to(self.device, non_blocking=True, memory_format=torch.channels_last)
+         ctx = torch.amp.autocast("cuda", dtype=self.torch_dtype) if self.use_amp else nullcontext()
+         with ctx:
+             f = self.model.get_image_features(pixel_values=x)
+         f = f.float()
+         return l2norm_rows(f)
+
+     @torch.inference_mode()
+     def logits(self, pil_list) -> torch.Tensor:
+         f = self.encode(pil_list)
+         return (f @ self.w + self.b).squeeze(1)
+
+     @torch.inference_mode()
+     def prob(self, pil_list) -> torch.Tensor:
+         z = torch.clamp(self.logits(pil_list), -50, 50)
+         return torch.sigmoid(z)
src/models/utils.py ADDED
@@ -0,0 +1,16 @@
+ import os, csv, torch
+
+ def l2norm_rows(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+     return x / (x.norm(dim=1, keepdim=True) + eps)
+
+ def load_tag_names(T: int, csv_name: str) -> list[str]:
+     # Read tag names from column 1 of the CSV next to this file; pad or truncate to length T.
+     p = os.path.join(os.path.dirname(__file__), csv_name)
+     names: list[str] = []
+     if os.path.isfile(p):
+         with open(p, "r", encoding="utf-8", newline="") as f:
+             for row in csv.reader(f):
+                 if len(row) > 1 and row[1].strip():
+                     names.append(row[1].strip())
+                 if len(names) >= T:
+                     return names[:T]
+     return names + [f"tag_{i:04d}" for i in range(len(names), T)]
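
A quick, illustrative check of `l2norm_rows`: each row is divided by its L2 norm, so every output row has (approximately) unit length.

```python
import torch
from src.models.utils import l2norm_rows

x = torch.tensor([[3.0, 4.0], [0.0, 2.0]])
print(l2norm_rows(x))               # [[0.6, 0.8], [0.0, 1.0]]
print(l2norm_rows(x).norm(dim=1))   # ~[1., 1.]
```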
src/models/wdeva02_tags.csv ADDED
The diff for this file is too large to render. See raw diff
 
video_utils.py ADDED
@@ -0,0 +1,63 @@
+ import os, shutil, subprocess
+
+ def has_ffmpeg() -> bool:
+     return shutil.which("ffmpeg") is not None and shutil.which("ffprobe") is not None
+
+ def probe_duration(video_path: str) -> float | None:
+     if not shutil.which("ffprobe"):
+         return None
+     try:
+         out = subprocess.check_output(
+             ["ffprobe", "-v", "error", "-select_streams", "v:0", "-show_entries", "stream=duration", "-of", "default=nw=1:nk=1", video_path],
+             stderr=subprocess.STDOUT, text=True
+         )
+         return float(out.strip())
+     except Exception:
+         return None
+
+ def extract_frames_ffmpeg(video_path: str, fps: float, out_dir: str) -> list[tuple[str, int]]:
+     # Sample the video at `fps` frames per second; returns (frame_path, frame_index) pairs.
+     os.makedirs(out_dir, exist_ok=True)
+     tpl = os.path.join(out_dir, "frame_%06d.jpg")
+     subprocess.check_call(
+         ["ffmpeg", "-y", "-i", video_path, "-vf", f"fps={fps}", "-qscale:v", "2", tpl],
+         stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+     )
+     frames = sorted([os.path.join(out_dir, f) for f in os.listdir(out_dir) if f.lower().endswith(".jpg")])
+     return [(p, i) for i, p in enumerate(frames)]
+
+ def runs_from_indices(idxs: list[int]) -> list[tuple[int, int]]:
+     # Collapse sorted frame indices into inclusive runs of consecutive indices.
+     if not idxs:
+         return []
+     idxs = sorted(idxs)
+     runs, s, prev = [], idxs[0], idxs[0]
+     for x in idxs[1:]:
+         if x == prev + 1:
+             prev = x
+         else:
+             runs.append((s, prev))
+             s = prev = x
+     runs.append((s, prev))
+     return runs
+
+ def merge_seconds_union(all_indices: list[int], fps: float, pad: float = 0.25) -> list[tuple[float, float]]:
+     # Convert hit frame indices into padded time intervals (seconds) and merge overlaps.
+     if not all_indices:
+         return []
+     runs = runs_from_indices(sorted(all_indices))
+     intervals = []
+     for a, b in runs:
+         start = max(0.0, a / fps - pad)
+         end = (b + 1) / fps + pad
+         intervals.append((start, end))
+     merged = []
+     for s, e in sorted(intervals):
+         if not merged or s > merged[-1][1]:
+             merged.append((s, e))
+         else:
+             merged[-1] = (merged[-1][0], max(merged[-1][1], e))
+     return merged
+
+ def redact_with_ffmpeg(video_path: str, intervals: list[tuple[float, float]], out_path: str):
+     # Drop every frame whose timestamp falls inside an interval, rebuild timestamps, strip audio.
+     if not intervals:
+         shutil.copyfile(video_path, out_path)
+         return
+     parts = [f"between(t\\,{s:.3f}\\,{e:.3f})" for s, e in intervals]
+     expr = f"not({' + '.join(parts)})"
+     vf = f"select='{expr}',setpts=N/FRAME_RATE/TB"
+     subprocess.check_call(["ffmpeg", "-y", "-i", video_path, "-vf", vf, "-an", out_path],
+                           stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
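
Putting the pieces together, here is how flagged frame indices become the padded, merged intervals that `redact_with_ffmpeg` cuts out (example numbers only, at a sampling rate of 2 frames per second):

```python
from video_utils import runs_from_indices, merge_seconds_union

hits = [3, 4, 5, 9]                       # hypothetical flagged frame indices
print(runs_from_indices(hits))            # [(3, 5), (9, 9)]
print(merge_seconds_union(hits, fps=2.0, pad=0.25))
# [(1.25, 3.25), (4.25, 5.25)]  -- frames 3-5 cover 1.5s-3.0s, frame 9 covers 4.5s-5.0s, each padded by 0.25s
```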