# hf_backend.py
import logging
import time
from contextlib import nullcontext
from typing import Any, AsyncIterable, Dict, Tuple

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)

try:
    import spaces
    from spaces.zero import client as zero_client
except ImportError:
    spaces, zero_client = None, None

# --- Model setup ---
MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"Preloading tokenizer for {MODEL_ID} on CPU...")

tokenizer, load_error = None, None
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        use_fast=False,
    )
except Exception as e:
    load_error = f"Failed to load tokenizer: {e}"
    logger.exception(load_error)


# ---------------- helpers ----------------
def _pick_cpu_dtype() -> torch.dtype:
    if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
        try:
            if torch.cpu.is_bf16_supported():
                logger.info("CPU BF16 supported, will attempt torch.bfloat16")
                return torch.bfloat16
        except Exception:
            pass
    logger.info("Falling back to torch.float32 on CPU")
    return torch.float32


# ---------------- global cache ----------------
_MODEL_CACHE: Dict[Tuple[str, torch.dtype], AutoModelForCausalLM] = {}


def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
    key = (device, dtype)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key], dtype

    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    if hasattr(cfg, "quantization_config"):
        logger.warning("Removing quantization_config from model config")
        delattr(cfg, "quantization_config")

    eff_dtype = dtype
    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map="auto" if device != "cpu" else {"": "cpu"},
            low_cpu_mem_usage=False,  # ensure full load before casting
        )
    except Exception as e:
        if device == "cpu" and dtype == torch.bfloat16:
            logger.warning(f"BF16 load failed on CPU: {e}. Retrying with FP32.")
            eff_dtype = torch.float32
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                config=cfg,
                torch_dtype=eff_dtype,
                trust_remote_code=True,
                device_map={"": "cpu"},
                low_cpu_mem_usage=False,
            )
        else:
            raise

    # --- Force recast to target dtype/device (fixes FP8 leftovers) ---
    model = model.to(device=device, dtype=eff_dtype)
    model.eval()

    # Cache under the effective dtype (may differ from the requested one after a fallback).
    _MODEL_CACHE[(device, eff_dtype)] = model
    return model, eff_dtype


# ---------------- Chat Backend ----------------
class HFChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        if load_error:
            raise RuntimeError(load_error)

        messages = request.get("messages", [])
        temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))

        rid = f"chatcmpl-hf-{int(time.time())}"
        now = int(time.time())

        # --- Inject X-IP-Token into global headers if ZeroGPU is used ---
        x_ip_token = request.get("x_ip_token")
        if x_ip_token and zero_client:
            zero_client.HEADERS["X-IP-Token"] = x_ip_token
            logger.debug("Injected X-IP-Token into ZeroGPU headers")

        # Build prompt using chat template if available
        if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
            try:
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                logger.debug("Applied chat template for prompt")
            except Exception as e:
                logger.warning(f"Failed to apply chat template: {e}, using fallback")
                prompt = messages[-1]["content"] if messages else "(empty)"
        else:
            prompt = messages[-1]["content"] if messages else "(empty)"

        def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
            model, eff_dtype = _get_model(device, req_dtype)

            inputs = tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}

            with torch.inference_mode():
                if device != "cpu":
                    autocast_ctx = torch.autocast(device_type=device, dtype=eff_dtype)
                else:
                    if eff_dtype == torch.bfloat16:
                        autocast_ctx = torch.autocast(device_type="cpu", dtype=torch.bfloat16)
                    else:
                        autocast_ctx = nullcontext()

                with autocast_ctx:
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        do_sample=True,
                        use_cache=True,
                    )

            # Decode only the newly generated tokens so the prompt is not echoed back.
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True)

        if spaces:

            @spaces.GPU(duration=120)
            def run_once(prompt: str) -> str:
                if torch.cuda.is_available():
                    return _run_once(prompt, device="cuda", req_dtype=torch.float16)
                return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

            text = run_once(prompt)
        else:
            text = _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": MODEL_ID,
            "choices": [
                {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
            ],
        }


# ---------------- Stub Images Backend ----------------
class StubImagesBackend(ImagesBackend):
    """
    Stub backend for images since HFChatBackend is text-only.
    Returns a 1x1 PNG placeholder.
    """

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in HF backend.")
        return (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
        )
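

# ---------------- Local smoke test (illustrative sketch) ----------------
# Not part of the backend contract: a minimal, hedged example of how
# HFChatBackend.stream() could be driven locally. It assumes `config.settings`
# and `backends_base` are importable in this environment and that running on
# CPU (no `spaces` installed) is acceptable; the request keys shown mirror the
# ones read above, everything else is for demonstration only.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        backend = HFChatBackend()
        request = {
            "messages": [{"role": "user", "content": "Say hello in one sentence."}],
            "temperature": 0.7,
            "max_tokens": 64,
        }
        # stream() yields OpenAI-style chat.completion.chunk dicts.
        async for chunk in backend.stream(request):
            print(chunk["choices"][0]["delta"]["content"])

    asyncio.run(_demo())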