File size: 4,447 Bytes
5e94db5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Dict, List, Tuple
from PIL import Image

# Category labels shared by the multi-label adapters in this module. Those
# adapters index their probability vectors by position in this list, so the
# order of the entries matters.
ALL_CATEGORIES = [
    "alcohol",
    "drugs",
    "weapons",
    "gambling",
    "nudity",
    "sexy",
    "smoking",
    "violence",
]

# Presumably the default decision threshold applied to scores; it is not
# referenced in the visible part of this module -- confirm against callers.
DEFAULT_THRESHOLD = 0.5

# Names of the adapters whose only category is "sexual".
NUDENET_ONLY = {
    "clip-nudenet-lp",
    "siglip-nudenet-lp",
}

class BaseModel:
    """Common interface implemented by every model adapter in this module.

    Subclasses override ``load`` and ``predict_image``; ``extra_selected_tags``
    is optional and defaults to returning no tags.
    """

    # Identifier of the adapter.
    name = "base"
    # Subclasses set this to True when they override ``extra_selected_tags``.
    supports_selected_tags = False
    # Category labels this adapter can score.
    categories = ALL_CATEGORIES

    def load(self):
        """Prepare the underlying model. Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def predict_image(
        self,
        pil_image: Image.Image,
        requested_categories: List[str],
    ) -> Dict[str, float]:
        """Return a mapping of category name to score for ``pil_image``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def extra_selected_tags(
        self,
        pil_image: Image.Image,
        top_k: int = 10,
    ) -> List[Tuple[str, float]]:
        """Return extra ``(tag, score)`` pairs; the base class has none."""
        return []

class Clip_MultiLabel(BaseModel):
    """Adapter that wraps the project's ``CLIPMultiLabel`` model."""

    name = "clip-multilabel"
    categories = ALL_CATEGORIES

    def __init__(self, head_path="weights/clip_multilabel.pt"):
        """Record the constructor arguments; the model is built in ``load``.

        Args:
            head_path: Path forwarded to ``CLIPMultiLabel`` as ``head_path``.
        """
        self._m = None
        self._cfg = {"head_path": head_path, "categories": self.categories}

    def load(self):
        """Instantiate ``CLIPMultiLabel`` unless an instance already exists."""
        # Imported lazily so importing this module does not pull in the model code.
        from src.models import CLIPMultiLabel

        if self._m is not None:
            return
        self._m = CLIPMultiLabel(**self._cfg)

    def predict_image(self, pil_image, requested_categories: List[str]) -> Dict[str, float]:
        """Score ``pil_image`` and keep only the requested categories.

        ``load`` must have been called first; otherwise ``self._m`` is None.

        Args:
            pil_image: Image passed to the wrapped model as a one-element batch.
            requested_categories: Category names to include in the result.

        Returns:
            ``{category: score}`` for each entry of ``self.categories`` that
            also appears in ``requested_categories``.
        """
        # Assumes prob() yields one row of per-category scores per input
        # image, aligned with self.categories -- TODO confirm in src.models.
        scores = self._m.prob([pil_image])[0].tolist()
        selected = {}
        for position, label in enumerate(self.categories):
            if label in requested_categories:
                selected[label] = float(scores[position])
        return selected

class _EVABaseAdapter(BaseModel):
    """Shared adapter logic for models built on ``EVAHeadPreserving``.

    Concrete subclasses supply ``name``, ``REPO_ID`` and ``TAG_CSV``.
    """

    supports_selected_tags = True
    # Passed to EVAHeadPreserving as ``repo_id``; set by subclasses.
    REPO_ID = ""
    # Passed to EVAHeadPreserving as ``tag_csv``; set by subclasses.
    TAG_CSV = ""

    def __init__(self, head_path: str):
        """Record the head path; the model is built in ``load``."""
        self._m = None
        self._cfg = {"head_path": head_path, "categories": self.categories}

    def load(self):
        """Instantiate ``EVAHeadPreserving`` unless an instance already exists."""
        # Imported lazily so importing this module does not pull in the model code.
        from src.models.eva_headpreserving import EVAHeadPreserving

        if self._m is not None:
            return
        self._m = EVAHeadPreserving(
            repo_id=self.REPO_ID,
            head_path=self._cfg["head_path"],
            categories=self.categories,
            tag_csv=self.TAG_CSV,
        )

    def predict_image(self, pil_image, requested_categories: List[str]) -> Dict[str, float]:
        """Score ``pil_image`` and keep only the requested categories.

        ``load`` must have been called first; otherwise ``self._m`` is None.

        Returns:
            ``{category: score}`` for each entry of ``self.categories`` that
            also appears in ``requested_categories``.
        """
        # Assumes prob() yields one row of per-category scores per input
        # image, aligned with self.categories -- TODO confirm.
        scores = self._m.prob([pil_image])[0].tolist()
        selected = {}
        for position, label in enumerate(self.categories):
            if label in requested_categories:
                selected[label] = float(scores[position])
        return selected

    def extra_selected_tags(self, pil_image: Image.Image, top_k: int = 50) -> List[Tuple[str, float]]:
        """Return the wrapped model's ``top_tags`` result for ``pil_image``.

        Note the default ``top_k`` here is 50, not the base class's 10.
        """
        tags = self._m.top_tags(pil_image, top_k=top_k)
        return tags

class WDEva02_Multitask(_EVABaseAdapter):
    """EVA adapter configured for the ``SmilingWolf/wd-eva02-large-tagger-v3`` repo."""

    name = "wdeva02-multitask"
    REPO_ID = "SmilingWolf/wd-eva02-large-tagger-v3"
    TAG_CSV = "wdeva02_tags.csv"

    def __init__(self, head_path="weights/wdeva02.pt"):
        """Forward ``head_path`` to the shared EVA adapter initialiser."""
        super().__init__(head_path=head_path)

class Animetimm_Multitask(_EVABaseAdapter):
    """EVA adapter configured for the ``animetimm/eva02_large_patch14_448.dbv4-full`` repo."""

    name = "animetimm-multitask"
    REPO_ID = "animetimm/eva02_large_patch14_448.dbv4-full"
    TAG_CSV = "animetimm_tags.csv"

    def __init__(self, head_path="weights/animetimm.pt"):
        """Forward ``head_path`` to the shared EVA adapter initialiser."""
        super().__init__(head_path=head_path)

class Clip_NudeNet_LP(BaseModel):
    """Adapter around ``CLIPLinearProbe`` that emits a single "sexual" score."""

    name = "clip-nudenet-lp"
    categories = ["sexual"]

    def __init__(self, head_path: str = "weights/clip_nudenet_lp.npz"):
        """Record the head path; the probe is built in ``load``."""
        self._lp = None
        self._cfg = {"head_path": head_path}

    def load(self):
        """Instantiate ``CLIPLinearProbe`` unless an instance already exists."""
        if self._lp is not None:
            return
        # Imported lazily, and only when the probe actually has to be built.
        from src.models import CLIPLinearProbe

        self._lp = CLIPLinearProbe(**self._cfg)

    def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
        """Return ``{"sexual": score}`` for ``pil_image``.

        ``requested_categories`` is accepted for interface compatibility but
        is not consulted: the single "sexual" entry is always returned.
        ``load`` must have been called first; otherwise ``self._lp`` is None.
        """
        batch_scores = self._lp.prob([pil_image])
        return {"sexual": float(batch_scores[0])}

class Siglip_NudeNet_LP(BaseModel):
    """Adapter around ``SigLIPLinearProbe`` that emits a single "sexual" score."""

    name = "siglip-nudenet-lp"
    categories = ["sexual"]

    def __init__(self, head_path: str = "weights/siglip_nudenet_lp.npz"):
        """Record the head path; the probe is built in ``load``."""
        self._lp = None
        self._cfg = {"head_path": head_path}

    def load(self):
        """Instantiate ``SigLIPLinearProbe`` unless an instance already exists."""
        if self._lp is not None:
            return
        # Imported lazily, and only when the probe actually has to be built.
        from src.models import SigLIPLinearProbe

        self._lp = SigLIPLinearProbe(**self._cfg)

    def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
        """Return ``{"sexual": score}`` for ``pil_image``.

        ``requested_categories`` is accepted for interface compatibility but
        is not consulted: the single "sexual" entry is always returned.
        ``load`` must have been called first; otherwise ``self._lp`` is None.
        """
        batch_scores = self._lp.prob([pil_image])
        return {"sexual": float(batch_scores[0])}

# Lookup table from public model name to a shared adapter instance. The
# adapters are constructed here, at import time, but their underlying models
# are only built when load() is called.
#
# NOTE(review): the keys "wdeva02-multilabel" and "animetimm-multilabel" do
# not match those adapters' ``name`` attributes ("wdeva02-multitask" and
# "animetimm-multitask"). Confirm which spelling callers rely on before
# changing either side.
REGISTRY: Dict[str, BaseModel] = dict(
    (
        ("clip-multilabel", Clip_MultiLabel()),
        ("wdeva02-multilabel", WDEva02_Multitask()),
        ("animetimm-multilabel", Animetimm_Multitask()),
        ("clip-nudenet-lp", Clip_NudeNet_LP()),
        ("siglip-nudenet-lp", Siglip_NudeNet_LP()),
    )
)

def get_model(name: str) -> BaseModel:
    """Return the registered adapter for ``name``, loading it on first use.

    Args:
        name: A key of ``REGISTRY``.

    Returns:
        The shared adapter instance stored under ``name``.

    Raises:
        KeyError: if ``name`` is not a key of ``REGISTRY``.
    """
    model = REGISTRY[name]
    if hasattr(model, "_loaded"):
        return model
    model.load()
    # The marker is set only after load() returns, so a load() that raises
    # leaves the adapter unmarked and the next call tries again.
    model._loaded = True
    return model