Spaces:
Runtime error
Runtime error
| """ | |
| Generation logic for Pixagram - Torch 2.1.1 + Depth Anything V2 optimized | |
| """ | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from config import * | |
| from utils import * | |
| from models import * | |
class RetroArtConverter:
    """Retro-art image generator built around an SDXL ControlNet pipeline.

    Construction loads every component through helpers brought in by the
    star imports (face analysis, depth estimator, ControlNets, optional
    InstantID image encoder and IP-Adapter, Compel, scheduler, caption model)
    and records which of them loaded in ``self.models_loaded``. The
    generation path checks these flags and falls back (grayscale depth, raw
    text prompts, depth-only control) when a component is unavailable.
    """

    def __init__(self):
        """Load all model components and assemble the SDXL pipeline.

        ``device`` and ``dtype`` are module-level names provided by one of
        the star imports at the top of the file.
        """
        self.device = device
        self.dtype = dtype
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False,
            'depth_detector': False,
            'ip_adapter': False,
        }

        # Face analysis with CPU fallback
        self.face_app, self.face_detection_enabled = load_face_analysis()

        # Depth detector with Depth Anything V2 priority
        self.depth_detector, depth_success, self.depth_type = load_depth_detector()
        self.models_loaded['depth_detector'] = depth_success
        print(f"[DEPTH] Using: {self.depth_type}")

        # ControlNets
        controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets()
        self.controlnet_depth = controlnet_depth
        self.instantid_enabled = instantid_success
        self.models_loaded['instantid'] = instantid_success

        # The image encoder is only needed for the InstantID / IP-Adapter path.
        self.image_encoder = load_image_encoder() if self.instantid_enabled else None

        # InstantID ControlNet goes first: generate_retro_art passes control
        # images and scales in the matching order [face keypoints, depth].
        if self.instantid_enabled and self.controlnet_instantid is not None:
            controlnets = [self.controlnet_instantid, controlnet_depth]
        else:
            controlnets = controlnet_depth

        # SDXL pipeline
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
        self.models_loaded['custom_checkpoint'] = checkpoint_success

        # LORA
        self.models_loaded['lora'] = load_lora(self.pipe)

        # IP-Adapter (requires the InstantID image encoder)
        if self.instantid_enabled and self.image_encoder is not None:
            self.image_proj_model, ip_adapter_success = setup_ip_adapter(self.pipe, self.image_encoder)
            self.models_loaded['ip_adapter'] = ip_adapter_success
        else:
            self.models_loaded['ip_adapter'] = False
            self.image_proj_model = None

        # Compel
        self.compel, self.use_compel = setup_compel(self.pipe)

        # LCM scheduler
        setup_scheduler(self.pipe)

        # TORCH 2.1.1: Apply optimizations (compile, etc.)
        optimize_pipeline(self.pipe)

        # Caption model
        self.caption_processor, self.caption_model, self.caption_enabled = load_caption_model()

        # CLIP skip
        set_clip_skip(self.pipe)

        self.using_multiple_controlnets = isinstance(controlnets, list)

        self._print_status()
        print(" [OK] Initialization complete")

    def _print_status(self):
        """Print one "[OK]" / "[FALLBACK]" line per entry in ``models_loaded``."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "[OK]" if loaded else "[FALLBACK]"
            print(f"{model}: {status}")
        print("====================\n")

    @staticmethod
    def _is_cuda_device(dev):
        """Return True if *dev* (a str such as "cuda:0" or a torch.device) is CUDA."""
        return str(dev).startswith("cuda")

    def get_depth_map(self, image):
        """Return a depth control image for *image*.

        Tries Depth Anything V2 or Zoe (whichever ``load_depth_detector``
        selected) and falls back to a 3-channel grayscale copy of the input if
        no detector is loaded or the detector raises.
        """
        if self.depth_type == "depth_anything_v2" and self.depth_detector is not None:
            try:
                result = self.depth_detector(image)
                depth_image = result["depth"]
                # Convert to PIL if needed
                if not isinstance(depth_image, Image.Image):
                    depth_image = Image.fromarray(np.array(depth_image))
                return depth_image
            except Exception as e:
                print(f"[WARNING] Depth Anything V2 failed: {e}, using fallback")

        if self.depth_type == "zoe" and self.depth_detector is not None:
            try:
                return self.depth_detector(image)
            except Exception as e:
                print(f"[WARNING] Zoe failed: {e}, using grayscale")

        # Grayscale fallback. COLOR_RGB2GRAY needs a 3-channel array, so
        # normalize PIL inputs in other modes (RGBA, L, P, ...) first.
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
        depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(depth_colored)

    def add_trigger_word(self, prompt):
        """Prefix ``TRIGGER_WORD`` unless the prompt already contains it (case-insensitive)."""
        if TRIGGER_WORD.lower() not in prompt.lower():
            return f"{TRIGGER_WORD}, {prompt}"
        return prompt

    def extract_multi_scale_face(self, face_crop, face):
        """Average the face embedding over several rescaled copies of *face_crop*.

        For each factor in ``MULTI_SCALE_FACTORS`` the crop is resized by that
        factor and back to its original size, then re-detected with
        ``self.face_app``. The embeddings of the first detected face at each
        scale are averaged and L2-normalized. Falls back to
        ``face.normed_embedding`` if no scale yields a face or anything fails.
        """
        try:
            w, h = face_crop.size
            embeddings = []
            for scale in MULTI_SCALE_FACTORS:
                # Clamp to 1px so very small factors cannot request a 0-size resize.
                scaled_size = (max(1, int(w * scale)), max(1, int(h * scale)))
                scaled_crop = face_crop.resize(scaled_size, Image.LANCZOS)
                scaled_crop = scaled_crop.resize((w, h), Image.LANCZOS)
                scaled_array = cv2.cvtColor(np.array(scaled_crop), cv2.COLOR_RGB2BGR)
                scaled_faces = self.face_app.get(scaled_array)
                if len(scaled_faces) > 0:
                    embeddings.append(scaled_faces[0].normed_embedding)
            if embeddings:
                averaged = np.mean(embeddings, axis=0)
                norm = np.linalg.norm(averaged)
                if norm > 0:  # avoid dividing by zero on degenerate averages
                    return averaged / norm
        except Exception as e:
            print(f"[WARNING] Multi-scale face embedding failed: {e}")
        return face.normed_embedding

    def detect_face_quality(self, face):
        """Return adaptive parameter overrides for hard faces, else None.

        Checks, in order: bbox area below ``ADAPTIVE_THRESHOLDS['small_face_size']``,
        detection score below ``['low_confidence']``, and |yaw| (``face.pose[1]``)
        above ``['profile_angle']``. Returns a copy of the matching
        ``ADAPTIVE_PARAMS`` entry, or None if no rule fires or inspection fails.
        """
        try:
            bbox = face.bbox
            face_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            # Attributes may be missing or None; treat either as "no information".
            raw_score = getattr(face, 'det_score', None)
            det_score = float(raw_score) if raw_score is not None else 1.0

            if face_size < ADAPTIVE_THRESHOLDS['small_face_size']:
                return ADAPTIVE_PARAMS['small_face'].copy()
            if det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
                return ADAPTIVE_PARAMS['low_confidence'].copy()

            pose = getattr(face, 'pose', None)
            if pose is not None and len(pose) > 1:
                try:
                    yaw = float(pose[1])
                except (TypeError, ValueError):
                    return None
                if abs(yaw) > ADAPTIVE_THRESHOLDS['profile_angle']:
                    return ADAPTIVE_PARAMS['profile_view'].copy()
            return None
        except Exception:
            return None

    def validate_and_adjust_parameters(self, strength, guidance_scale, lora_scale,
                                       identity_preservation, identity_control_scale,
                                       depth_control_scale, consistency_mode=True):
        """Clamp user parameters into combinations that keep results consistent.

        Rules, applied only when ``consistency_mode`` is true and in this order:
          * identity_preservation > 1.2 -> lora_scale capped at 1.0
          * strength < 0.5 -> identity_preservation >= 1.3, lora_scale <= 0.9
          * strength > 0.7 -> identity_preservation <= 1.0, lora_scale >= 1.2
          * guidance_scale always clamped to [1.0, 1.5]
        ``identity_control_scale`` and ``depth_control_scale`` pass through.

        Returns:
            ``(strength, guidance_scale, lora_scale, identity_preservation,
            identity_control_scale, depth_control_scale)``
        """
        if not consistency_mode:
            return (strength, guidance_scale, lora_scale, identity_preservation,
                    identity_control_scale, depth_control_scale)

        adjustments = []

        def record(label, before, after):
            # Remember only values that actually changed, for the summary log.
            if after != before:
                adjustments.append(f"{label}: {before:.2f}->{after:.2f}")
            return after

        if identity_preservation > 1.2:
            lora_scale = record("LORA", lora_scale, min(lora_scale, 1.0))

        if strength < 0.5:
            identity_preservation = record(
                "identity", identity_preservation, max(identity_preservation, 1.3))
            lora_scale = record("LORA", lora_scale, min(lora_scale, 0.9))
        elif strength > 0.7:
            identity_preservation = record(
                "identity", identity_preservation, min(identity_preservation, 1.0))
            lora_scale = record("LORA", lora_scale, max(lora_scale, 1.2))

        # NOTE(review): the 1.5 ceiling presumably suits the LCM scheduler set
        # up in __init__ (low CFG) -- confirm against setup_scheduler.
        guidance_scale = record("CFG", guidance_scale, max(1.0, min(guidance_scale, 1.5)))

        if adjustments:
            print(f" [OK] Applied adjustments: {', '.join(adjustments)}")
        return (strength, guidance_scale, lora_scale, identity_preservation,
                identity_control_scale, depth_control_scale)

    def generate_caption(self, image, max_length=None, num_beams=None):
        """Caption *image* with the loaded caption model.

        ``max_length`` / ``num_beams`` default to ``CAPTION_CONFIG``. Returns
        the decoded caption, or None if captioning is disabled or fails.
        """
        if not self.caption_enabled or self.caption_model is None:
            return None
        if max_length is None:
            max_length = CAPTION_CONFIG['max_length']
        if num_beams is None:
            num_beams = CAPTION_CONFIG['num_beams']
        try:
            inputs = self.caption_processor(image, return_tensors="pt").to(self.device, self.dtype)
            with torch.no_grad():
                output = self.caption_model.generate(**inputs, max_length=max_length, num_beams=num_beams)
            return self.caption_processor.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            print(f"[WARNING] Caption generation failed: {e}")
            return None

    def _set_prompt_kwargs(self, pipe_kwargs, prompt, negative_prompt):
        """Add prompt inputs to *pipe_kwargs*: Compel embeddings if possible, else raw text.

        Embeddings are only written after both prompts encode successfully,
        so a Compel failure never leaves ``prompt_embeds`` mixed with ``prompt``.
        """
        if self.use_compel and self.compel is not None:
            try:
                conditioning = self.compel(prompt)
                negative_conditioning = self.compel(negative_prompt)
                embeds = {
                    "prompt_embeds": conditioning[0],
                    "pooled_prompt_embeds": conditioning[1],
                    "negative_prompt_embeds": negative_conditioning[0],
                    "negative_pooled_prompt_embeds": negative_conditioning[1],
                }
            except Exception as e:
                print(f"[WARNING] Compel encoding failed: {e}, using raw prompts")
            else:
                pipe_kwargs.update(embeds)
                return
        pipe_kwargs["prompt"] = prompt
        pipe_kwargs["negative_prompt"] = negative_prompt

    def _apply_color_matching(self, generated_image, reference_image, face_bbox):
        """Colour-match *generated_image* to *reference_image*; best effort.

        Uses ``enhanced_color_match`` around *face_bbox* when a face box is
        given, otherwise ``color_match(mode='mkl')``. On failure the
        unmodified image is returned.
        """
        try:
            if face_bbox is not None:
                return enhanced_color_match(generated_image, reference_image, face_bbox=face_bbox)
            return color_match(generated_image, reference_image, mode='mkl')
        except Exception as e:
            print(f"[WARNING] Color matching failed: {e}, keeping raw output")
            return generated_image

    def generate_retro_art(
        self,
        input_image,
        prompt="retro game character",
        negative_prompt="blurry, low quality",
        num_inference_steps=12,
        guidance_scale=1.0,
        depth_control_scale=0.8,
        identity_control_scale=0.85,
        lora_scale=1.0,
        identity_preservation=0.8,
        strength=0.75,
        enable_color_matching=False,
        consistency_mode=True,
        seed=-1
    ):
        """Turn *input_image* into retro art with the loaded SDXL pipeline.

        Sanitizes prompts, optionally clamps parameters
        (``validate_and_adjust_parameters``), resizes to a recommended size,
        builds a depth control image and -- when multiple ControlNets are
        loaded and a face is detected -- adds face-keypoint control plus
        IP-Adapter identity embeddings before running the pipeline.

        Args:
            input_image: PIL image; converted to RGB if it uses another mode.
            prompt, negative_prompt: Text prompts (the trigger word is added).
            num_inference_steps, guidance_scale, strength: Pipeline settings.
            depth_control_scale, identity_control_scale: ControlNet weights.
            lora_scale: Weight for the "retroart" LoRA adapter.
            identity_preservation: IP-Adapter scale before
                ``IDENTITY_BOOST_MULTIPLIER`` is applied.
            enable_color_matching: Match output colours to the resized input.
            consistency_mode: Enable parameter clamping.
            seed: -1 for a random seed, otherwise a fixed seed.

        Returns:
            The first image produced by the pipeline, optionally colour-matched.
        """
        prompt = sanitize_text(prompt)
        negative_prompt = sanitize_text(negative_prompt)

        if consistency_mode:
            (strength, guidance_scale, lora_scale, identity_preservation,
             identity_control_scale, depth_control_scale) = self.validate_and_adjust_parameters(
                strength, guidance_scale, lora_scale, identity_preservation,
                identity_control_scale, depth_control_scale, consistency_mode
            )

        prompt = self.add_trigger_word(prompt)

        # cv2 RGB conversions and VAE encoding below expect 3-channel input;
        # uploads may be RGBA / palette / grayscale.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")

        original_width, original_height = input_image.size
        target_width, target_height = calculate_optimal_size(original_width, original_height, RECOMMENDED_SIZES)
        target_size = (int(target_width), int(target_height))
        resized_image = input_image.resize(target_size, Image.LANCZOS)

        print("Generating depth map...")
        depth_image = self.get_depth_map(resized_image)
        if depth_image.size != target_size:
            depth_image = depth_image.resize(target_size, Image.LANCZOS)

        using_multiple_controlnets = self.using_multiple_controlnets
        face_kps_image = None
        face_embeddings = None
        face_crop_enhanced = None
        has_detected_faces = False
        face_bbox_original = None

        if using_multiple_controlnets and self.face_app is not None:
            print("Detecting faces...")
            img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
            faces = self.face_app.get(img_array)

            if len(faces) > 0:
                has_detected_faces = True
                # Use the largest face by bbox area.
                face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))

                adaptive_params = self.detect_face_quality(face)
                if adaptive_params is not None:
                    print(f"[ADAPTIVE] {adaptive_params['reason']}")
                    identity_preservation = adaptive_params['identity_preservation']
                    identity_control_scale = adaptive_params['identity_control_scale']
                    guidance_scale = adaptive_params['guidance_scale']
                    lora_scale = adaptive_params['lora_scale']

                img_w, img_h = resized_image.size
                bbox = face.bbox.astype(int)
                x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
                # Detector boxes may extend past the border; clip so colour
                # matching never gets negative (wrapping) indices.
                face_bbox_original = [max(0, x1), max(0, y1), min(img_w, x2), min(img_h, y2)]

                # Crop with 30% padding on each side, clipped to the image.
                pad_x = int((x2 - x1) * 0.3)
                pad_y = int((y2 - y1) * 0.3)
                face_crop = resized_image.crop((
                    max(0, x1 - pad_x),
                    max(0, y1 - pad_y),
                    min(img_w, x2 + pad_x),
                    min(img_h, y2 + pad_y),
                ))

                face_embeddings = self.extract_multi_scale_face(face_crop, face)
                face_crop_enhanced = enhance_face_crop(face_crop)
                face_kps_image = draw_kps(resized_image, face.kps)

                # ENHANCED: Use new facial attributes extraction
                facial_attrs = get_facial_attributes(face)
                prompt = build_enhanced_prompt(prompt, facial_attrs, TRIGGER_WORD)

        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
            try:
                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
            except Exception as e:
                print(f"[WARNING] Could not set LORA scale: {e}")

        pipe_kwargs = {
            "image": resized_image,
            "strength": strength,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }

        if seed == -1:
            generator = torch.Generator(device=self.device)
            actual_seed = generator.seed()  # draws a fresh non-deterministic seed
        else:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            actual_seed = seed
        print(f"Seed: {actual_seed}")
        pipe_kwargs["generator"] = generator

        self._set_prompt_kwargs(pipe_kwargs, prompt, negative_prompt)

        if hasattr(self.pipe, 'text_encoder'):
            pipe_kwargs["clip_skip"] = 2

        if using_multiple_controlnets and has_detected_faces and face_kps_image is not None:
            pipe_kwargs["control_image"] = [face_kps_image, depth_image]
            pipe_kwargs["controlnet_conditioning_scale"] = [identity_control_scale, depth_control_scale]

            if (face_embeddings is not None
                    and self.models_loaded.get('ip_adapter', False)
                    and face_crop_enhanced is not None):
                with torch.no_grad():
                    insightface_embeds = torch.from_numpy(face_embeddings).to(
                        device=self.device, dtype=self.dtype
                    ).unsqueeze(0).unsqueeze(1)
                    image_embeds = self.image_proj_model(insightface_embeds)
                boosted_scale = identity_preservation * IDENTITY_BOOST_MULTIPLIER
                pipe_kwargs["added_cond_kwargs"] = {"image_embeds": image_embeds, "time_ids": None}
                pipe_kwargs["cross_attention_kwargs"] = {"ip_adapter_scale": boosted_scale}
        else:
            if using_multiple_controlnets and not has_detected_faces:
                # Keep the InstantID slot filled but switched off (scale 0).
                pipe_kwargs["control_image"] = [depth_image, depth_image]
                pipe_kwargs["controlnet_conditioning_scale"] = [0.0, depth_control_scale]
            else:
                pipe_kwargs["control_image"] = depth_image
                pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale

            if self.models_loaded.get('ip_adapter', False):
                dummy_embeds = torch.zeros(
                    (1, 4, self.pipe.unet.config.cross_attention_dim),
                    device=self.device, dtype=self.dtype
                )
                pipe_kwargs["added_cond_kwargs"] = {"image_embeds": dummy_embeds, "time_ids": None}
                pipe_kwargs["cross_attention_kwargs"] = {"ip_adapter_scale": 0.0}

        print(f"Generating (steps={num_inference_steps}, cfg={guidance_scale}, strength={strength})...")
        if self._is_cuda_device(self.device) and hasattr(torch.backends.cuda, 'sdp_kernel'):
            # TORCH 2.1.1: pin SDPA backends explicitly (overriding any global
            # toggles). The dispatcher still prefers flash / memory-efficient
            # kernels; math stays enabled as the fallback so inputs those
            # kernels reject don't abort with "No available kernel".
            with torch.backends.cuda.sdp_kernel(
                enable_flash=True,
                enable_mem_efficient=True,
                enable_math=True
            ):
                result = self.pipe(**pipe_kwargs)
        else:
            result = self.pipe(**pipe_kwargs)

        generated_image = result.images[0]

        if enable_color_matching:
            # face_bbox_original is set exactly when a face was detected.
            generated_image = self._apply_color_matching(
                generated_image, resized_image, face_bbox_original
            )

        return generated_image
# Logged once at import time, after the generator class has been defined.
_READY_BANNER = "[OK] Generator ready (Torch 2.1.1 + Depth Anything V2)"
print(_READY_BANNER)