# pixagram-dev / generator.py (uploaded by primerz)
# Commit f179fb3 ("Upload 11 files", verified), 18 kB
"""
Generation logic for Pixagram - Torch 2.1.1 + Depth Anything V2 optimized
"""
import torch
import numpy as np
import cv2
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
from config import *
from utils import *
from models import *
class RetroArtConverter:
    """Main retro art generator with torch 2.1.1 optimizations.

    Construction loads every component the generator uses through the
    ``load_*`` / ``setup_*`` helpers brought in by the star imports at the top
    of the file. These components are:

    * face analysis
    * depth detector
    * ControlNets
    * SDXL pipeline
    * LORA
    * IP-Adapter
    * Compel
    * caption model

    Which components loaded successfully is recorded in ``self.models_loaded``.
    Call :meth:`generate_retro_art` to convert an image.
    """

    def __init__(self):
        self.device = device
        self.dtype = dtype
        # Per-component load status; False means a fallback path is in use.
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False,
            'depth_detector': False,
            'ip_adapter': False,
        }

        # Face analysis with CPU fallback
        self.face_app, self.face_detection_enabled = load_face_analysis()

        # Depth detector with Depth Anything V2 priority
        self.depth_detector, depth_success, self.depth_type = load_depth_detector()
        self.models_loaded['depth_detector'] = depth_success
        print(f"[DEPTH] Using: {self.depth_type}")

        # ControlNets
        controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets()
        self.controlnet_depth = controlnet_depth
        self.instantid_enabled = instantid_success
        self.models_loaded['instantid'] = instantid_success

        # The image encoder is only needed for the InstantID / IP-Adapter path.
        self.image_encoder = load_image_encoder() if self.instantid_enabled else None

        # InstantID goes first: generate_retro_art builds control images and
        # scales in the order [identity, depth].
        if self.instantid_enabled and self.controlnet_instantid is not None:
            controlnets = [self.controlnet_instantid, controlnet_depth]
        else:
            controlnets = controlnet_depth

        # SDXL pipeline
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
        self.models_loaded['custom_checkpoint'] = checkpoint_success

        # LORA
        self.models_loaded['lora'] = load_lora(self.pipe)

        # IP-Adapter (requires InstantID and the image encoder)
        if self.instantid_enabled and self.image_encoder is not None:
            self.image_proj_model, ip_adapter_success = setup_ip_adapter(self.pipe, self.image_encoder)
            self.models_loaded['ip_adapter'] = ip_adapter_success
        else:
            self.models_loaded['ip_adapter'] = False
            self.image_proj_model = None

        # Compel
        self.compel, self.use_compel = setup_compel(self.pipe)

        # LCM scheduler
        setup_scheduler(self.pipe)

        # TORCH 2.1.1: Apply optimizations (compile, etc.)
        optimize_pipeline(self.pipe)

        # Caption model
        self.caption_processor, self.caption_model, self.caption_enabled = load_caption_model()

        # CLIP skip
        set_clip_skip(self.pipe)

        self.using_multiple_controlnets = isinstance(controlnets, list)
        self._print_status()
        print(" [OK] Initialization complete")

    def _print_status(self):
        """Print the load status of every tracked model component."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "[OK]" if loaded else "[FALLBACK]"
            print(f"{model}: {status}")
        print("====================\n")

    @staticmethod
    def _normalize_depth_array(depth):
        """Convert an array-like depth map to a squeezed ``uint8`` numpy array.

        Accepts numpy arrays and torch tensors:

        * ``uint8`` data is returned as-is after squeezing.
        * Other dtypes are min-max scaled to 0..255.
        * A constant map becomes all zeros.
        """
        if isinstance(depth, torch.Tensor):
            depth = depth.detach().float().cpu().numpy()
        depth_array = np.squeeze(np.asarray(depth))
        if depth_array.dtype == np.uint8:
            return depth_array
        depth_array = depth_array.astype(np.float32)
        lo = float(depth_array.min())
        hi = float(depth_array.max())
        if hi <= lo:
            return np.zeros(depth_array.shape, dtype=np.uint8)
        scaled = (depth_array - lo) / (hi - lo) * 255.0
        return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

    def get_depth_map(self, image):
        """Return a depth map for *image* as a PIL image.

        Depth Anything V2 or Zoe is used, depending on ``self.depth_type``.
        If the detector is missing or fails, the grayscale version of the
        input, replicated to three channels, is returned instead.
        """
        if self.depth_type == "depth_anything_v2" and self.depth_detector is not None:
            try:
                result = self.depth_detector(image)
                depth_image = result["depth"]
                # Raw arrays/tensors (possibly float) must become an 8-bit image;
                # Image.fromarray on float data would yield an unusable 'F' image.
                if not isinstance(depth_image, Image.Image):
                    depth_image = Image.fromarray(self._normalize_depth_array(depth_image))
                return depth_image
            except Exception as e:
                print(f"[WARNING] Depth Anything V2 failed: {e}, using fallback")

        if self.depth_type == "zoe" and self.depth_detector is not None:
            try:
                return self.depth_detector(image)
            except Exception as e:
                print(f"[WARNING] Zoe failed: {e}, using grayscale")

        # Grayscale fallback. Convert first so 'L' / 'P' / 'RGBA' inputs do not
        # trip the RGB-only cv2 conversion.
        gray = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2GRAY)
        depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(depth_colored)

    def add_trigger_word(self, prompt):
        """Prefix ``TRIGGER_WORD`` to *prompt* unless it already contains it.

        The containment check is case-insensitive. An empty or ``None`` prompt
        becomes just the trigger word, with no dangling ", ".
        """
        if not prompt or not prompt.strip():
            return TRIGGER_WORD
        if TRIGGER_WORD.lower() not in prompt.lower():
            return f"{TRIGGER_WORD}, {prompt}"
        return prompt

    def extract_multi_scale_face(self, face_crop, face):
        """Return an identity embedding averaged over several crop scales.

        The crop is processed once per factor in ``MULTI_SCALE_FACTORS``:

        1. Downscale/upscale it by the factor, then resize it back to the
           original crop size.
        2. Re-run face analysis on the result (``face_app.get`` is fed BGR
           arrays).
        3. Collect the first detected face's ``normed_embedding``.

        The mean of the collected embeddings is L2-normalised and returned.
        If no scale yields a face, or anything fails, ``face.normed_embedding``
        is returned instead.
        """
        fallback = face.normed_embedding
        try:
            width, height = face_crop.size
            multi_scale_embeds = []
            for scale in MULTI_SCALE_FACTORS:
                # Guard against 0-pixel sizes for tiny crops / small factors.
                scaled_size = (max(1, int(width * scale)), max(1, int(height * scale)))
                scaled_crop = face_crop.resize(scaled_size, Image.LANCZOS)
                scaled_crop = scaled_crop.resize((width, height), Image.LANCZOS)
                scaled_array = cv2.cvtColor(np.array(scaled_crop.convert("RGB")), cv2.COLOR_RGB2BGR)
                scaled_faces = self.face_app.get(scaled_array)
                if len(scaled_faces) > 0:
                    multi_scale_embeds.append(scaled_faces[0].normed_embedding)

            if not multi_scale_embeds:
                return fallback
            averaged = np.mean(multi_scale_embeds, axis=0)
            norm = np.linalg.norm(averaged)
            if norm == 0:
                return fallback
            return averaged / norm
        except Exception as e:
            print(f"[WARNING] Multi-scale face extraction failed: {e}, using base embedding")
            return fallback

    def detect_face_quality(self, face):
        """Pick adaptive generation parameters for a detected face.

        The rules are checked in order:

        1. Small faces.
        2. Low detection confidence.
        3. Profile views (``|yaw| > profile_angle``, using ``face.pose[1]`` as yaw).

        The first rule that matches returns a copy of the corresponding
        ``ADAPTIVE_PARAMS`` entry. Returns ``None`` when no rule matches or
        the face data cannot be read.
        """
        try:
            bbox = face.bbox
            face_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            det_score = float(face.det_score) if hasattr(face, 'det_score') else 1.0

            if face_size < ADAPTIVE_THRESHOLDS['small_face_size']:
                return ADAPTIVE_PARAMS['small_face'].copy()
            if det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
                return ADAPTIVE_PARAMS['low_confidence'].copy()

            pose = getattr(face, 'pose', None)
            if pose is not None and len(pose) > 1:
                try:
                    yaw = float(pose[1])
                except (TypeError, ValueError):
                    return None
                if abs(yaw) > ADAPTIVE_THRESHOLDS['profile_angle']:
                    return ADAPTIVE_PARAMS['profile_view'].copy()
            return None
        except Exception as e:
            print(f"[WARNING] Face quality check failed: {e}")
            return None

    def validate_and_adjust_parameters(self, strength, guidance_scale, lora_scale,
                                       identity_preservation, identity_control_scale,
                                       depth_control_scale, consistency_mode=True):
        """Clamp generation parameters into a mutually consistent range.

        When ``consistency_mode`` is true:

        * ``identity_preservation > 1.2`` caps ``lora_scale`` at 1.0;
        * ``strength < 0.5`` raises ``identity_preservation`` to at least 1.3
          and caps ``lora_scale`` at 0.9;
        * ``strength > 0.7`` caps ``identity_preservation`` at 1.0 and raises
          ``lora_scale`` to at least 1.2;
        * ``guidance_scale`` is clamped to ``[1.0, 1.5]``.

        Every change larger than 0.01 is reported on stdout. ``strength``,
        ``identity_control_scale`` and ``depth_control_scale`` pass through
        unchanged.

        Returns:
            ``(strength, guidance_scale, lora_scale, identity_preservation,
            identity_control_scale, depth_control_scale)``
        """
        if consistency_mode:
            before = {
                "CFG": guidance_scale,
                "LORA": lora_scale,
                "Identity": identity_preservation,
            }

            if identity_preservation > 1.2:
                lora_scale = min(lora_scale, 1.0)
            if strength < 0.5:
                identity_preservation = max(identity_preservation, 1.3)
                lora_scale = min(lora_scale, 0.9)
            elif strength > 0.7:
                identity_preservation = min(identity_preservation, 1.0)
                # NOTE: this floor overrides the 1.0 LORA cap applied above
                # for identity_preservation > 1.2.
                lora_scale = max(lora_scale, 1.2)
            guidance_scale = max(1.0, min(guidance_scale, 1.5))

            after = {
                "CFG": guidance_scale,
                "LORA": lora_scale,
                "Identity": identity_preservation,
            }
            adjustments = [
                f"{name}: {before[name]:.2f}->{after[name]:.2f}"
                for name in before
                if abs(after[name] - before[name]) > 0.01
            ]
            if adjustments:
                print(f" [OK] Applied adjustments: {', '.join(adjustments)}")

        return (strength, guidance_scale, lora_scale, identity_preservation,
                identity_control_scale, depth_control_scale)

    def generate_caption(self, image, max_length=None, num_beams=None):
        """Caption *image* with the loaded caption model.

        ``max_length`` and ``num_beams`` default to the values in
        ``CAPTION_CONFIG``. Returns ``None`` when captioning is disabled, the
        model is missing, or generation fails.
        """
        if not self.caption_enabled or self.caption_model is None:
            return None
        if max_length is None:
            max_length = CAPTION_CONFIG['max_length']
        if num_beams is None:
            num_beams = CAPTION_CONFIG['num_beams']
        try:
            inputs = self.caption_processor(image, return_tensors="pt").to(self.device, self.dtype)
            with torch.no_grad():
                output = self.caption_model.generate(**inputs, max_length=max_length, num_beams=num_beams)
            return self.caption_processor.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            print(f"[WARNING] Caption generation failed: {e}")
            return None

    def _make_generator(self, seed):
        """Create a ``torch.Generator`` on ``self.device``.

        Returns ``(generator, actual_seed)``. A ``seed`` of -1 (or ``None``)
        draws a fresh non-deterministic seed. Any other value is coerced to
        ``int``, because UI widgets often deliver numbers as floats, and used
        as the fixed seed.
        """
        generator = torch.Generator(device=self.device)
        if seed is None or int(seed) == -1:
            return generator, generator.seed()
        actual_seed = int(seed)
        return generator.manual_seed(actual_seed), actual_seed

    def _add_prompt_kwargs(self, pipe_kwargs, prompt, negative_prompt):
        """Add prompt inputs to *pipe_kwargs* in place.

        Uses Compel embeddings when Compel is enabled. If Compel is disabled
        or raises, the raw prompt strings are passed instead.
        """
        if self.use_compel and self.compel is not None:
            try:
                conditioning = self.compel(prompt)
                negative_conditioning = self.compel(negative_prompt)
                embeds = {
                    "prompt_embeds": conditioning[0],
                    "pooled_prompt_embeds": conditioning[1],
                    "negative_prompt_embeds": negative_conditioning[0],
                    "negative_pooled_prompt_embeds": negative_conditioning[1],
                }
            except Exception as e:
                print(f"[WARNING] Compel encoding failed: {e}, using plain prompts")
            else:
                pipe_kwargs.update(embeds)
                return
        pipe_kwargs["prompt"] = prompt
        pipe_kwargs["negative_prompt"] = negative_prompt

    def _run_pipeline(self, pipe_kwargs):
        """Run the diffusion pipeline, preferring fast SDPA kernels on CUDA.

        The math backend stays enabled as a fallback. With it disabled, SDPA
        raises "No available kernel" whenever flash and mem-efficient
        attention are not eligible for the given inputs.
        """
        on_cuda = str(self.device).startswith("cuda")
        if on_cuda and hasattr(torch.backends.cuda, 'sdp_kernel'):
            with torch.backends.cuda.sdp_kernel(
                enable_flash=True,
                enable_mem_efficient=True,
                enable_math=True,
            ):
                return self.pipe(**pipe_kwargs)
        return self.pipe(**pipe_kwargs)

    def _apply_color_matching(self, generated_image, reference_image, face_bbox=None):
        """Colour-match *generated_image* to *reference_image*.

        Uses the face-aware ``enhanced_color_match`` when a face bbox is known
        and MKL ``color_match`` otherwise. On failure, the image is returned
        unchanged with a warning.
        """
        try:
            if face_bbox is not None:
                return enhanced_color_match(generated_image, reference_image, face_bbox=face_bbox)
            return color_match(generated_image, reference_image, mode='mkl')
        except Exception as e:
            print(f"[WARNING] Color matching failed: {e}")
            return generated_image

    def generate_retro_art(
        self,
        input_image,
        prompt="retro game character",
        negative_prompt="blurry, low quality",
        num_inference_steps=12,
        guidance_scale=1.0,
        depth_control_scale=0.8,
        identity_control_scale=0.85,
        lora_scale=1.0,
        identity_preservation=0.8,
        strength=0.75,
        enable_color_matching=False,
        consistency_mode=True,
        seed=-1
    ):
        """Generate retro art from *input_image* with torch 2.1.1 optimizations.

        Args:
            input_image: PIL image to convert. Non-RGB modes are converted to RGB.
            prompt: text prompt; sanitized and prefixed with the trigger word.
            negative_prompt: negative text prompt; sanitized.
            num_inference_steps: number of diffusion steps.
            guidance_scale: classifier-free guidance scale.
            strength: img2img strength.
            depth_control_scale: ControlNet weight for the depth map.
            identity_control_scale: ControlNet weight for the face keypoints.
            lora_scale: LORA adapter weight.
            identity_preservation: IP-Adapter scale before
                ``IDENTITY_BOOST_MULTIPLIER`` is applied.
            enable_color_matching: colour-match the output to the resized input.
            consistency_mode: clamp parameters via
                :meth:`validate_and_adjust_parameters`. When a face is found,
                adaptive face parameters may still override these values.
            seed: -1 (or ``None``) for a random seed, otherwise a fixed seed.

        Returns:
            The first image of the pipeline output. It is not resized back to
            the input's original size.
        """
        prompt = sanitize_text(prompt)
        negative_prompt = sanitize_text(negative_prompt)

        if consistency_mode:
            (strength, guidance_scale, lora_scale, identity_preservation,
             identity_control_scale, depth_control_scale) = self.validate_and_adjust_parameters(
                strength, guidance_scale, lora_scale, identity_preservation,
                identity_control_scale, depth_control_scale, consistency_mode
            )
        prompt = self.add_trigger_word(prompt)

        # The cv2 RGB conversions below and the diffusion pipeline need 3-channel
        # RGB. RGBA / palette / grayscale uploads would otherwise fail.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")

        original_width, original_height = input_image.size
        target_width, target_height = calculate_optimal_size(original_width, original_height, RECOMMENDED_SIZES)
        target_size = (int(target_width), int(target_height))
        resized_image = input_image.resize(target_size, Image.LANCZOS)

        print("Generating depth map...")
        depth_image = self.get_depth_map(resized_image)
        if depth_image.size != target_size:
            depth_image = depth_image.resize(target_size, Image.LANCZOS)

        using_multiple_controlnets = self.using_multiple_controlnets
        face_kps_image = None
        face_embeddings = None
        face_crop_enhanced = None
        has_detected_faces = False
        face_bbox_original = None

        if using_multiple_controlnets and self.face_app is not None:
            print("Detecting faces...")
            img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
            faces = self.face_app.get(img_array)
            if len(faces) > 0:
                has_detected_faces = True
                # Work with the largest face by bounding-box area.
                face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))

                adaptive_params = self.detect_face_quality(face)
                if adaptive_params is not None:
                    print(f"[ADAPTIVE] {adaptive_params['reason']}")
                    identity_preservation = adaptive_params['identity_preservation']
                    identity_control_scale = adaptive_params['identity_control_scale']
                    guidance_scale = adaptive_params['guidance_scale']
                    lora_scale = adaptive_params['lora_scale']

                bbox = face.bbox.astype(int)
                x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
                face_bbox_original = [x1, y1, x2, y2]

                # Pad the crop by 30% of the face size per side, clamped to the image.
                padding_x = int((x2 - x1) * 0.3)
                padding_y = int((y2 - y1) * 0.3)
                crop_box = (
                    max(0, x1 - padding_x),
                    max(0, y1 - padding_y),
                    min(resized_image.width, x2 + padding_x),
                    min(resized_image.height, y2 + padding_y),
                )
                face_crop = resized_image.crop(crop_box)

                face_embeddings = self.extract_multi_scale_face(face_crop, face)
                face_crop_enhanced = enhance_face_crop(face_crop)
                face_kps_image = draw_kps(resized_image, face.kps)

                # ENHANCED: Use new facial attributes extraction
                facial_attrs = get_facial_attributes(face)
                prompt = build_enhanced_prompt(prompt, facial_attrs, TRIGGER_WORD)

        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
            try:
                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
            except Exception as e:
                print(f"[WARNING] Could not set LORA scale: {e}")

        pipe_kwargs = {
            "image": resized_image,
            "strength": strength,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }

        generator, actual_seed = self._make_generator(seed)
        pipe_kwargs["generator"] = generator
        print(f"Seed: {actual_seed}")

        self._add_prompt_kwargs(pipe_kwargs, prompt, negative_prompt)

        if hasattr(self.pipe, 'text_encoder'):
            pipe_kwargs["clip_skip"] = 2

        if using_multiple_controlnets and has_detected_faces and face_kps_image is not None:
            # Order must match the ControlNet list built in __init__: [identity, depth].
            pipe_kwargs["control_image"] = [face_kps_image, depth_image]
            pipe_kwargs["controlnet_conditioning_scale"] = [identity_control_scale, depth_control_scale]

            if (face_embeddings is not None
                    and self.models_loaded.get('ip_adapter', False)
                    and face_crop_enhanced is not None):
                with torch.no_grad():
                    insightface_embeds = torch.from_numpy(face_embeddings).to(
                        device=self.device, dtype=self.dtype
                    ).unsqueeze(0).unsqueeze(1)
                    image_embeds = self.image_proj_model(insightface_embeds)
                boosted_scale = identity_preservation * IDENTITY_BOOST_MULTIPLIER
                pipe_kwargs["added_cond_kwargs"] = {"image_embeds": image_embeds, "time_ids": None}
                pipe_kwargs["cross_attention_kwargs"] = {"ip_adapter_scale": boosted_scale}
        else:
            if using_multiple_controlnets:
                # No face: the identity ControlNet still needs an input, so it
                # gets the depth map with zero weight.
                pipe_kwargs["control_image"] = [depth_image, depth_image]
                pipe_kwargs["controlnet_conditioning_scale"] = [0.0, depth_control_scale]
            else:
                pipe_kwargs["control_image"] = depth_image
                pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale

            if self.models_loaded.get('ip_adapter', False):
                # Zero-scale placeholder embeddings keep the IP-Adapter path fed.
                dummy_embeds = torch.zeros(
                    (1, 4, self.pipe.unet.config.cross_attention_dim),
                    device=self.device, dtype=self.dtype
                )
                pipe_kwargs["added_cond_kwargs"] = {"image_embeds": dummy_embeds, "time_ids": None}
                pipe_kwargs["cross_attention_kwargs"] = {"ip_adapter_scale": 0.0}

        print(f"Generating (steps={num_inference_steps}, cfg={guidance_scale}, strength={strength})...")
        result = self._run_pipeline(pipe_kwargs)
        generated_image = result.images[0]

        if enable_color_matching:
            # face_bbox_original is only set when a face was detected.
            generated_image = self._apply_color_matching(
                generated_image, resized_image, face_bbox=face_bbox_original
            )

        return generated_image
# Import-time banner: announces that this module finished loading.
_READY_BANNER = "[OK] Generator ready (Torch 2.1.1 + Depth Anything V2)"
print(_READY_BANNER)