content_moderation_demo / model_registry.py
onullusoy's picture
Upload 12 files
5e94db5 verified
raw
history blame
4.45 kB
from typing import Dict, List, Tuple
from PIL import Image
# Category names scored by the multi-label models; also the default
# ``categories`` attribute of BaseModel and its multi-label subclasses.
ALL_CATEGORIES = [
    "alcohol",
    "drugs",
    "weapons",
    "gambling",
    "nudity",
    "sexy",
    "smoking",
    "violence",
]

# Score cut-off exposed to callers; not referenced by the registry code below
# (presumably applied by the UI / caller — confirm).
DEFAULT_THRESHOLD = 0.5

# Registry names of the linear-probe models that emit a single "sexual" score
# instead of the ALL_CATEGORIES vector.
NUDENET_ONLY = {
    "clip-nudenet-lp",
    "siglip-nudenet-lp",
}
class BaseModel:
    """Interface shared by every model in the registry.

    Subclasses override ``load`` and ``predict_image``. ``extra_selected_tags``
    is optional; the base implementation returns no tags and is paired with
    ``supports_selected_tags = False``.
    """

    # Short identifier of the model (not always identical to its REGISTRY key).
    name = "base"
    # True only for subclasses whose extra_selected_tags() returns real tags.
    supports_selected_tags = False
    # Category names this model can score.
    categories = ALL_CATEGORIES

    def load(self):
        """Prepare the underlying model. Must be overridden."""
        raise NotImplementedError

    def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
        """Return a category -> score mapping for ``pil_image``. Must be overridden."""
        raise NotImplementedError

    def extra_selected_tags(self, pil_image: Image.Image, top_k: int = 10) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` (tag, score) pairs; the base model has none."""
        return []
class Clip_MultiLabel(BaseModel):
    """Registry adapter around ``src.models.CLIPMultiLabel`` over ALL_CATEGORIES."""

    name = "clip-multilabel"
    categories = ALL_CATEGORIES

    def __init__(self, head_path="weights/clip_multilabel.pt"):
        # Constructor kwargs for CLIPMultiLabel; the wrapper itself is only
        # built in load().
        self._cfg = dict(head_path=head_path, categories=self.categories)
        self._m = None

    def load(self):
        """Build the CLIPMultiLabel wrapper once; later calls are no-ops."""
        # Deferred import, presumably to keep importing this registry cheap.
        from src.models import CLIPMultiLabel
        if self._m is None:
            self._m = CLIPMultiLabel(**self._cfg)

    def predict_image(self, pil_image, requested_categories: List[str]) -> Dict[str, float]:
        """Score ``pil_image`` and keep only the requested categories.

        Args:
            pil_image: Image forwarded (wrapped in a one-element list) to the
                wrapper's ``prob`` method.
            requested_categories: Category names to include; names that are
                not in ``self.categories`` are silently ignored.

        Returns:
            Mapping of category name to the float score returned by ``prob``,
            in ``self.categories`` order.
        """
        # Load on demand: without this, using the instance before load()
        # (e.g. straight from REGISTRY) fails with AttributeError on None.
        if self._m is None:
            self.load()
        p = self._m.prob([pil_image])[0].tolist()
        return {c: float(p[i]) for i, c in enumerate(self.categories) if c in requested_categories}
class _EVABaseAdapter(BaseModel):
    """Shared registry adapter around ``EVAHeadPreserving``.

    Concrete subclasses set ``name``, ``REPO_ID`` and ``TAG_CSV``. Both class
    attributes and the ``head_path`` given to ``__init__`` are forwarded to
    ``EVAHeadPreserving`` in ``load``.
    """

    supports_selected_tags = True
    # Overridden by subclasses; passed to EVAHeadPreserving as repo_id / tag_csv.
    REPO_ID = ""
    TAG_CSV = ""

    def __init__(self, head_path: str):
        self._cfg = dict(head_path=head_path, categories=self.categories)
        self._m = None

    def load(self):
        """Build the EVAHeadPreserving wrapper once; later calls are no-ops."""
        from src.models.eva_headpreserving import EVAHeadPreserving
        if self._m is None:
            self._m = EVAHeadPreserving(
                repo_id=self.REPO_ID,
                head_path=self._cfg["head_path"],
                categories=self.categories,
                tag_csv=self.TAG_CSV,
            )

    def _ensure_loaded(self):
        """Call load() if the wrapper has not been built yet."""
        if self._m is None:
            self.load()

    def predict_image(self, pil_image, requested_categories: List[str]) -> Dict[str, float]:
        """Score ``pil_image`` and keep only the requested categories.

        Args:
            pil_image: Image forwarded (wrapped in a one-element list) to the
                wrapper's ``prob`` method.
            requested_categories: Category names to include; names that are
                not in ``self.categories`` are silently ignored.

        Returns:
            Mapping of category name to the float score returned by ``prob``,
            in ``self.categories`` order.
        """
        # Load on demand instead of failing with AttributeError on None.
        self._ensure_loaded()
        p = self._m.prob([pil_image])[0].tolist()
        return {c: float(p[i]) for i, c in enumerate(self.categories) if c in requested_categories}

    def extra_selected_tags(self, pil_image: Image.Image, top_k: int = 50) -> List[Tuple[str, float]]:
        """Return the wrapper's ``top_tags`` for ``pil_image``, limited to ``top_k``."""
        self._ensure_loaded()
        return self._m.top_tags(pil_image, top_k=top_k)
class WDEva02_Multitask(_EVABaseAdapter):
    """EVA adapter configured for the SmilingWolf WD EVA02-Large tagger (v3)."""

    REPO_ID = "SmilingWolf/wd-eva02-large-tagger-v3"
    TAG_CSV = "wdeva02_tags.csv"
    name = "wdeva02-multitask"

    def __init__(self, head_path="weights/wdeva02.pt"):
        """Create the adapter; ``head_path`` is forwarded to _EVABaseAdapter."""
        super().__init__(head_path=head_path)
class Animetimm_Multitask(_EVABaseAdapter):
    """EVA adapter configured for the animetimm EVA02-Large (dbv4-full) tagger."""

    REPO_ID = "animetimm/eva02_large_patch14_448.dbv4-full"
    TAG_CSV = "animetimm_tags.csv"
    name = "animetimm-multitask"

    def __init__(self, head_path="weights/animetimm.pt"):
        """Create the adapter; ``head_path`` is forwarded to _EVABaseAdapter."""
        super().__init__(head_path=head_path)
class Clip_NudeNet_LP(BaseModel):
    """Registry adapter around ``src.models.CLIPLinearProbe``.

    Produces a single ``"sexual"`` score. That key is not part of
    ALL_CATEGORIES, and ``predict_image`` returns it regardless of
    ``requested_categories``.
    """

    name = "clip-nudenet-lp"
    categories = ["sexual"]

    def __init__(self, head_path: str = "weights/clip_nudenet_lp.npz"):
        # Constructor kwargs for CLIPLinearProbe; the probe is built in load().
        self._cfg = dict(head_path=head_path)
        self._lp = None

    def load(self):
        """Build the CLIPLinearProbe once; later calls are no-ops."""
        if self._lp is None:
            from src.models import CLIPLinearProbe
            self._lp = CLIPLinearProbe(**self._cfg)

    def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
        """Return ``{"sexual": score}`` for ``pil_image``.

        ``requested_categories`` is not consulted (presumably intentional:
        filtering against a list such as ALL_CATEGORIES would drop the only
        score this model produces — confirm with callers).
        """
        # Load on demand instead of failing with AttributeError on None.
        if self._lp is None:
            self.load()
        return {"sexual": float(self._lp.prob([pil_image])[0])}
class Siglip_NudeNet_LP(BaseModel):
    """Registry adapter around ``src.models.SigLIPLinearProbe``.

    Produces a single ``"sexual"`` score. That key is not part of
    ALL_CATEGORIES, and ``predict_image`` returns it regardless of
    ``requested_categories``.
    """

    name = "siglip-nudenet-lp"
    categories = ["sexual"]

    def __init__(self, head_path: str = "weights/siglip_nudenet_lp.npz"):
        # Constructor kwargs for SigLIPLinearProbe; the probe is built in load().
        self._cfg = dict(head_path=head_path)
        self._lp = None

    def load(self):
        """Build the SigLIPLinearProbe once; later calls are no-ops."""
        if self._lp is None:
            from src.models import SigLIPLinearProbe
            self._lp = SigLIPLinearProbe(**self._cfg)

    def predict_image(self, pil_image: Image.Image, requested_categories: List[str]) -> Dict[str, float]:
        """Return ``{"sexual": score}`` for ``pil_image``.

        ``requested_categories`` is not consulted (presumably intentional:
        filtering against a list such as ALL_CATEGORIES would drop the only
        score this model produces — confirm with callers).
        """
        # Load on demand instead of failing with AttributeError on None.
        if self._lp is None:
            self.load()
        return {"sexual": float(self._lp.prob([pil_image])[0])}
# Registry key -> shared model instance. Instances are constructed here but
# not loaded; get_model() calls load() on first request.
# NOTE: the two EVA entries are keyed "*-multilabel" while their ``name``
# attributes read "*-multitask"; lookups go through the key.
REGISTRY = {
    key: factory()
    for key, factory in (
        ("clip-multilabel", Clip_MultiLabel),
        ("wdeva02-multilabel", WDEva02_Multitask),
        ("animetimm-multilabel", Animetimm_Multitask),
        ("clip-nudenet-lp", Clip_NudeNet_LP),
        ("siglip-nudenet-lp", Siglip_NudeNet_LP),
    )
}
def get_model(name: str) -> BaseModel:
    """Return the registered model for ``name``, loading it on first request.

    Args:
        name: A key of ``REGISTRY``.

    Returns:
        The shared instance stored in ``REGISTRY``; repeated calls return the
        same object.

    Raises:
        KeyError: If ``name`` is not a registered key; the message lists the
            available keys.
    """
    try:
        model = REGISTRY[name]
    except KeyError:
        available = ", ".join(sorted(REGISTRY))
        raise KeyError(f"Unknown model {name!r}; available: {available}") from None
    # The flag is set only after load() returns, so a load that raised is
    # retried on the next call.
    if not getattr(model, "_loaded", False):
        model.load()
        model._loaded = True
    return model