from __future__ import annotations

from contextlib import nullcontext
from typing import List, Tuple

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import create_transform, resolve_model_data_config

from .utils import load_tag_names


class EVAHeadPreserving:
    """
    Head-preserving inference for EVA-02 backbones (Animetimm / WD-EVA02).
    Interface: encode / logits / prob / tags_prob / top_tags
    """

    def __init__(self, repo_id: str, head_path: str, categories: List[str],
                 tag_csv: str = "selected_tags.csv"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.categories = list(categories)
        self.tag_csv = tag_csv

        # Frozen pretrained backbone with its original tag head preserved.
        self.backbone = timm.create_model(f"hf-hub:{repo_id}", pretrained=True)
        self.backbone = self.backbone.to(self.device).eval().requires_grad_(False)
        cfg = resolve_model_data_config(self.backbone)
        self.preprocess = create_transform(**cfg)

        # Probe the model once to learn D (pre-logit feature width) and
        # T (number of tag logits from the original head).
        with torch.no_grad():
            in_size = cfg.get("input_size", (3, 448, 448))
            h, w = int(in_size[-2]), int(in_size[-1])
            dummy = torch.zeros(1, 3, h, w, device=self.device)
            fx = self.backbone.forward_features(dummy)
            pre = self.backbone.forward_head(fx, pre_logits=True)
            tags_log = self.backbone.forward_head(fx, pre_logits=False)
            D, T = int(pre.shape[-1]), int(tags_log.shape[-1])

        # Custom category head, loaded from a standalone checkpoint.
        self.custom_head = nn.Linear(D, len(self.categories)).to(self.device).eval().requires_grad_(False)
        ckpt = torch.load(head_path, map_location=self.device, weights_only=True)
        state = ckpt.get("state_dict", ckpt)
        weight = state["head.weight"].to(self.device).float()
        bias = state["head.bias"].to(self.device).float()
        # Accept checkpoints that stored the weight matrix transposed.
        if weight.shape != self.custom_head.weight.shape and weight.t().shape == self.custom_head.weight.shape:
            weight = weight.t()
        with torch.no_grad():
            self.custom_head.weight.copy_(weight)
            self.custom_head.bias.copy_(bias)

        self.tag_names = load_tag_names(T, self.tag_csv)

        # AMP autocast is built for CUDA in encode(), so enable it only there.
        self.use_amp = self.device == "cuda"
        if self.device == "cuda":
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True

    @torch.inference_mode()
    def encode(self, pil_list: List) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (L2-normalized pre-logit features, raw tag logits)."""
        x = torch.stack([self.preprocess(im.convert("RGB")) for im in pil_list], 0)
        x = x.to(self.device, non_blocking=True, memory_format=torch.channels_last)
        ctx = torch.amp.autocast("cuda", dtype=self.torch_dtype) if self.use_amp else nullcontext()
        with ctx:
            fx = self.backbone.forward_features(x)
            pre = self.backbone.forward_head(fx, pre_logits=True)
            feat = F.normalize(pre, dim=1)
            tags_log = self.backbone.forward_head(fx, pre_logits=False)
        return feat.float(), tags_log.float()

    @torch.inference_mode()
    def logits(self, pil_list: List) -> torch.Tensor:
        # The custom head is applied to the normalized features from encode().
        feat_norm, _ = self.encode(pil_list)
        return self.custom_head(feat_norm)

    @torch.inference_mode()
    def prob(self, pil_list: List) -> torch.Tensor:
        # Clamping keeps probabilities strictly inside (0, 1).
        z = torch.clamp(self.logits(pil_list), -20, 20)
        return torch.sigmoid(z)

    @torch.inference_mode()
    def tags_prob(self, pil_list: List) -> torch.Tensor:
        _, tags_log = self.encode(pil_list)
        z = torch.clamp(tags_log, -20, 20)
        return torch.sigmoid(z)

    @torch.inference_mode()
    def top_tags(self, pil_image, top_k: int = 50):
        """Return the top_k (tag_name, probability) pairs for one image."""
        p = self.tags_prob([pil_image])[0].tolist()
        k = max(0, min(top_k, len(p)))
        idx = sorted(range(len(p)), key=lambda i: -p[i])[:k]
        names = self.tag_names
        return [(names[i] if i < len(names) else f"tag_{i:04d}", float(p[i])) for i in idx]
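
# Usage sketch (illustrative): a minimal driver showing the intended call flow.
# The hub repo id, head checkpoint path, category labels, and image path below
# are assumptions for demonstration, not artifacts shipped with this module.
# Run via `python -m <package>.<module>` so the relative import above resolves.
if __name__ == "__main__":
    from PIL import Image

    model = EVAHeadPreserving(
        repo_id="SmilingWolf/wd-eva02-large-tagger-v3",  # assumption: any timm-loadable EVA-02 hub id
        head_path="custom_head.pt",                      # assumption: checkpoint with head.weight / head.bias
        categories=["general", "sensitive", "explicit"], # assumption: example category labels
    )
    img = Image.open("example.jpg")                      # assumption: any local image file

    print("category probs:", model.prob([img]))          # sigmoid over the custom category head
    for name, score in model.top_tags(img, top_k=10):    # top-10 tags from the preserved backbone head
        print(f"{name}: {score:.3f}")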