Spaces:
Sleeping
Sleeping
| from typing import Dict, List, Tuple | |
| from PIL import Image | |
| ALL_CATEGORIES = [ | |
| "alcohol","drugs","weapons","gambling", | |
| "nudity","sexy","smoking","violence" | |
| ] | |
| DEFAULT_THRESHOLD = 0.5 | |
| NUDENET_ONLY = {"clip-nudenet-lp", "siglip-nudenet-lp"} | |
class BaseModel:
    """Common interface shared by every moderation model in the registry.

    Subclasses set ``name`` and ``categories`` and implement ``load`` and
    ``predict_image``. Models that can also report free-form tags set
    ``supports_selected_tags`` to True and override ``extra_selected_tags``.
    """

    name = "base"
    supports_selected_tags = False
    categories = ALL_CATEGORIES

    def load(self):
        """Prepare the underlying model. Subclasses must implement this."""
        raise NotImplementedError

    def predict_image(
        self,
        pil_image: Image.Image,
        requested_categories: List[str],
    ) -> Dict[str, float]:
        """Return a ``{category: score}`` mapping for *pil_image*.

        Subclasses must implement this.
        """
        raise NotImplementedError

    def extra_selected_tags(
        self,
        pil_image: Image.Image,
        top_k: int = 10,
    ) -> List[Tuple[str, float]]:
        """Return up to *top_k* ``(tag, score)`` pairs; the base model has none."""
        no_tags: List[Tuple[str, float]] = []
        return no_tags
class Clip_MultiLabel(BaseModel):
    """Multi-label CLIP classifier scoring every category in ``ALL_CATEGORIES``.

    The wrapped ``src.models.CLIPMultiLabel`` object is built by :meth:`load`.
    :meth:`predict_image` also triggers that load on first use, so an
    instance taken directly from the registry (without ``get_model``) works.
    """

    name = "clip-multilabel"
    categories = ALL_CATEGORIES

    def __init__(self, head_path="weights/clip_multilabel.pt"):
        """Record the configuration; nothing is loaded until :meth:`load`.

        Args:
            head_path: Head-weights path forwarded to ``CLIPMultiLabel``.
        """
        self._cfg = dict(head_path=head_path, categories=self.categories)
        self._m = None

    def load(self):
        """Build the wrapped model once; subsequent calls are no-ops."""
        if self._m is None:
            # Imported here so that importing this module does not import
            # src.models until a model is actually needed.
            from src.models import CLIPMultiLabel
            self._m = CLIPMultiLabel(**self._cfg)

    def predict_image(self, pil_image, requested_categories: List[str]) -> Dict[str, float]:
        """Score *pil_image* and keep only the requested categories.

        Args:
            pil_image: Image to classify.
            requested_categories: Category names to return; names this model
                does not know are ignored.

        Returns:
            ``{category: score}`` for each requested known category, in
            ``self.categories`` order.
        """
        if self._m is None:
            # Previously this failed with "'NoneType' object has no
            # attribute 'prob'" when the model had not been loaded yet.
            self.load()
        # Assumes prob() returns one row per input image whose columns are
        # aligned with self.categories -- TODO confirm against CLIPMultiLabel.
        scores = self._m.prob([pil_image])[0].tolist()
        return {
            cat: float(scores[idx])
            for idx, cat in enumerate(self.categories)
            if cat in requested_categories
        }
class _EVABaseAdapter(BaseModel):
    """Shared adapter for EVA02 tagger backbones with a moderation head.

    Subclasses set ``REPO_ID`` (forwarded as ``repo_id``) and ``TAG_CSV``
    (forwarded as ``tag_csv``) to ``EVAHeadPreserving``. Besides category
    scores, these models can report their top tags via
    :meth:`extra_selected_tags`. The wrapped model is loaded on first use
    if :meth:`load` has not been called explicitly.
    """

    supports_selected_tags = True
    REPO_ID = ""
    TAG_CSV = ""

    def __init__(self, head_path: str):
        """Record the configuration; nothing is loaded until :meth:`load`.

        Args:
            head_path: Head-weights path forwarded to ``EVAHeadPreserving``.
        """
        self._cfg = dict(head_path=head_path, categories=self.categories)
        self._m = None

    def load(self):
        """Build the wrapped ``EVAHeadPreserving`` once; later calls are no-ops."""
        if self._m is None:
            # Imported here so that importing this module does not import
            # the EVA model code until a model is actually needed.
            from src.models.eva_headpreserving import EVAHeadPreserving
            self._m = EVAHeadPreserving(
                repo_id=self.REPO_ID,
                head_path=self._cfg["head_path"],
                categories=self.categories,
                tag_csv=self.TAG_CSV,
            )

    def _ensure_loaded(self):
        """Load the wrapped model if it has not been built yet."""
        if self._m is None:
            self.load()

    def predict_image(self, pil_image, requested_categories: List[str]) -> Dict[str, float]:
        """Score *pil_image* and keep only the requested categories.

        Args:
            pil_image: Image to classify.
            requested_categories: Category names to return; unknown names
                are ignored.

        Returns:
            ``{category: score}`` for each requested known category, in
            ``self.categories`` order.
        """
        self._ensure_loaded()
        # Assumes prob() returns one row per input image whose columns are
        # aligned with self.categories -- TODO confirm against EVAHeadPreserving.
        scores = self._m.prob([pil_image])[0].tolist()
        return {
            cat: float(scores[idx])
            for idx, cat in enumerate(self.categories)
            if cat in requested_categories
        }

    def extra_selected_tags(self, pil_image: Image.Image, top_k: int = 50) -> List[Tuple[str, float]]:
        """Return the wrapped model's top tags for *pil_image*.

        Note the default ``top_k`` here (50) differs from ``BaseModel``'s (10).
        """
        self._ensure_loaded()
        return self._m.top_tags(pil_image, top_k=top_k)
class WDEva02_Multitask(_EVABaseAdapter):
    """EVA02 adapter using the ``SmilingWolf/wd-eva02-large-tagger-v3`` backbone.

    Note: ``name`` reads "wdeva02-multitask", while the module registry keys
    this model as "wdeva02-multilabel".
    """

    REPO_ID = "SmilingWolf/wd-eva02-large-tagger-v3"
    TAG_CSV = "wdeva02_tags.csv"
    name = "wdeva02-multitask"

    def __init__(self, head_path="weights/wdeva02.pt"):
        """Configure the adapter with *head_path* as its head weights."""
        super().__init__(head_path=head_path)
class Animetimm_Multitask(_EVABaseAdapter):
    """EVA02 adapter using the ``animetimm/eva02_large_patch14_448.dbv4-full`` backbone.

    Note: ``name`` reads "animetimm-multitask", while the module registry
    keys this model as "animetimm-multilabel".
    """

    REPO_ID = "animetimm/eva02_large_patch14_448.dbv4-full"
    TAG_CSV = "animetimm_tags.csv"
    name = "animetimm-multitask"

    def __init__(self, head_path="weights/animetimm.pt"):
        """Configure the adapter with *head_path* as its head weights."""
        super().__init__(head_path=head_path)
class Clip_NudeNet_LP(BaseModel):
    """CLIP linear probe that emits a single ``"sexual"`` score.

    The probe is built by :meth:`load`; :meth:`predict_image` also triggers
    that load on first use if it has not happened yet.
    """

    name = "clip-nudenet-lp"
    categories = ["sexual"]

    def __init__(self, head_path: str = "weights/clip_nudenet_lp.npz"):
        """Record the configuration; nothing is loaded until :meth:`load`.

        Args:
            head_path: Probe-weights path forwarded to ``CLIPLinearProbe``.
        """
        self._cfg = dict(head_path=head_path)
        self._lp = None

    def load(self):
        """Build the wrapped ``CLIPLinearProbe`` once; later calls are no-ops."""
        if self._lp is None:
            from src.models import CLIPLinearProbe
            self._lp = CLIPLinearProbe(**self._cfg)

    def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
        """Return ``{"sexual": score}`` for *pil_image*.

        *requested_categories* is not consulted: the single score is always
        returned.
        """
        if self._lp is None:
            # Previously this failed with "'NoneType' object has no
            # attribute 'prob'" when the probe had not been loaded yet.
            self.load()
        return {"sexual": float(self._lp.prob([pil_image])[0])}
class Siglip_NudeNet_LP(BaseModel):
    """SigLIP linear probe that emits a single ``"sexual"`` score.

    The probe is built by :meth:`load`; :meth:`predict_image` also triggers
    that load on first use if it has not happened yet.
    """

    name = "siglip-nudenet-lp"
    categories = ["sexual"]

    def __init__(self, head_path: str = "weights/siglip_nudenet_lp.npz"):
        """Record the configuration; nothing is loaded until :meth:`load`.

        Args:
            head_path: Probe-weights path forwarded to ``SigLIPLinearProbe``.
        """
        self._cfg = dict(head_path=head_path)
        self._lp = None

    def load(self):
        """Build the wrapped ``SigLIPLinearProbe`` once; later calls are no-ops."""
        if self._lp is None:
            from src.models import SigLIPLinearProbe
            self._lp = SigLIPLinearProbe(**self._cfg)

    def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
        """Return ``{"sexual": score}`` for *pil_image*.

        *requested_categories* is not consulted: the single score is always
        returned.
        """
        if self._lp is None:
            # Previously this failed with "'NoneType' object has no
            # attribute 'prob'" when the probe had not been loaded yet.
            self.load()
        return {"sexual": float(self._lp.prob([pil_image])[0])}
# Registry key -> shared model instance. The instances are created at import
# time, but their constructors only record configuration; weights are loaded
# later via load().
REGISTRY: Dict[str, BaseModel] = {
    # Multi-label classifier over ALL_CATEGORIES.
    "clip-multilabel": Clip_MultiLabel(),
    # EVA02 taggers: keyed "*-multilabel" although their .name is "*-multitask".
    "wdeva02-multilabel": WDEva02_Multitask(),
    "animetimm-multilabel": Animetimm_Multitask(),
    # Single-score ("sexual") linear probes; their keys are listed in NUDENET_ONLY.
    "clip-nudenet-lp": Clip_NudeNet_LP(),
    "siglip-nudenet-lp": Siglip_NudeNet_LP(),
}
def get_model(name: str) -> BaseModel:
    """Return the registered model for *name*, loading it on first access.

    *name* is normally a ``REGISTRY`` key. A model's own ``name`` attribute is
    accepted as a fallback, because some registry keys differ from the
    ``name`` their model reports (e.g. key "wdeva02-multilabel" vs. name
    "wdeva02-multitask"). Without the fallback, ``get_model(model.name)``
    raised ``KeyError`` for those models.

    Args:
        name: Registry key or model ``name`` attribute.

    Returns:
        The shared model instance, with ``load()`` already called.

    Raises:
        KeyError: If *name* matches neither a registry key nor a model name.
    """
    model = REGISTRY.get(name)
    if model is None:
        model = next((m for m in REGISTRY.values() if m.name == name), None)
    if model is None:
        available = ", ".join(sorted(REGISTRY))
        raise KeyError(f"unknown model {name!r}; available: {available}")
    # The flag is set only after load() returns, so a failed load is retried
    # on the next call instead of leaving a half-initialised model marked ready.
    if not getattr(model, "_loaded", False):
        model.load()
        model._loaded = True
    return model