Spaces:
Runtime error
Runtime error
File size: 11,473 Bytes
f179fb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 |
"""
Utility functions for Pixagram - Enhanced facial attributes
"""
import numpy as np
import cv2
import math
from PIL import Image, ImageEnhance, ImageFilter, ImageDraw
from config import COLOR_MATCH_CONFIG, FACE_MASK_CONFIG, AGE_BRACKETS
def sanitize_text(text):
    """Strip characters that downstream text handling may choke on.

    Drops anything UTF-8 cannot encode (e.g. lone surrogates) and every
    character outside the Basic Multilingual Plane (code point >= 0x10000,
    which covers most emoji).

    Args:
        text: String to clean. Falsy values (None, "") are returned unchanged.

    Returns:
        The cleaned string. Inputs that do not behave like ``str`` (e.g. bytes)
        are returned untouched.
    """
    if not text:
        return text
    try:
        # errors='ignore' drops unencodable code points instead of raising.
        text = text.encode('utf-8', errors='ignore').decode('utf-8')
    except (AttributeError, TypeError, UnicodeError):
        # Not a str: keep the original best-effort behaviour and pass it through.
        return text
    return ''.join(char for char in text if ord(char) < 0x10000)
def get_facial_attributes(face):
    """Extract demographic, expression, pose and quality attributes from a face.

    Each attribute is read best-effort: a missing attribute or a value that
    cannot be converted leaves that attribute's default in place.

    Args:
        face: Detected-face object. Optionally exposes ``age``, ``gender``
            (1 is labelled "male", 0 "female"), ``emotion`` (list, tuple or
            NumPy array of per-emotion scores), ``pose`` (index 1 is read as
            yaw) and ``det_score``.

    Returns:
        dict with keys:
            'age': int or None.
            'gender': int gender code or None.
            'expression': 'smiling', another non-neutral emotion name, or None.
            'quality': float detection score (default 1.0).
            'pose_angle': absolute yaw (default 0).
            'description': list of prompt labels, in order: AGE_BRACKETS label,
                "male"/"female", "smiling".
    """
    attributes = {
        'age': None,
        'gender': None,
        'expression': None,
        'quality': 1.0,
        'pose_angle': 0,
        'description': [],
    }
    # Value problems tolerated per attribute (None values, NaN/inf, odd shapes).
    # Anything else is a genuine bug and is allowed to propagate.
    tolerated = (TypeError, ValueError, OverflowError, AttributeError, IndexError)

    # Age: numeric value plus the first matching AGE_BRACKETS label.
    try:
        if hasattr(face, 'age'):
            age = int(face.age)
            attributes['age'] = age
            for min_age, max_age, label in AGE_BRACKETS:
                if min_age <= age < max_age:
                    attributes['description'].append(label)
                    break
    except tolerated:
        pass

    # Gender: 1 -> "male", 0 -> "female"; other codes are stored but not labelled.
    try:
        if hasattr(face, 'gender'):
            gender_code = int(face.gender)
            attributes['gender'] = gender_code
            if gender_code == 1:
                attributes['description'].append("male")
            elif gender_code == 0:
                attributes['description'].append("female")
    except tolerated:
        pass

    # Expression: per-emotion scores may arrive as a list, tuple or NumPy array.
    try:
        emotion = getattr(face, 'emotion', None)
        if isinstance(emotion, (list, tuple, np.ndarray)):
            scores = np.asarray(emotion, dtype=np.float64).ravel()
            if scores.size > 0:
                emotions = ['neutral', 'happiness', 'surprise', 'sadness', 'anger', 'disgust', 'fear']
                emotion_idx = int(np.argmax(scores))
                emotion_name = emotions[emotion_idx] if emotion_idx < len(emotions) else 'neutral'
                confidence = float(scores[emotion_idx])
                # Only report reasonably confident, non-neutral expressions;
                # only "smiling" is added to the prompt description.
                if confidence > 0.4:
                    if emotion_name == 'happiness':
                        attributes['expression'] = 'smiling'
                        attributes['description'].append('smiling')
                    elif emotion_name != 'neutral':
                        attributes['expression'] = emotion_name
    except tolerated:
        pass

    # Pose: absolute value of face.pose[1] (treated as yaw).
    try:
        if hasattr(face, 'pose') and len(face.pose) > 1:
            attributes['pose_angle'] = abs(float(face.pose[1]))
    except tolerated:
        pass

    # Quality: detector confidence score.
    try:
        if hasattr(face, 'det_score'):
            attributes['quality'] = float(face.det_score)
    except tolerated:
        pass

    return attributes
def build_enhanced_prompt(base_prompt, facial_attributes, trigger_word):
    """Insert detected facial labels into a prompt right after the trigger word.

    The labels in ``facial_attributes['description']`` are inserted after the
    first occurrence of ``trigger_word`` as
    ``"<trigger>, <label>, <label> person"``. Nothing is inserted if there are
    no labels, or if any label already appears in the prompt. That check is a
    case-insensitive substring test, so demographics the user already wrote
    take precedence.

    Args:
        base_prompt: The user prompt.
        facial_attributes: dict shaped like the output of
            ``get_facial_attributes``. Missing keys are tolerated.
        trigger_word: Token after which the labels are inserted. If it does
            not occur in ``base_prompt``, the prompt comes back unchanged.

    Returns:
        The (possibly) enhanced prompt. Detected attributes are printed
        whenever labels are inserted.
    """
    descriptions = facial_attributes.get('description') or []
    if not descriptions:
        return base_prompt

    prompt_lower = base_prompt.lower()
    # Substring test: e.g. "male" also matches "female". This errs on the side
    # of never contradicting demographics the user already specified.
    if any(desc.lower() in prompt_lower for desc in descriptions):
        return base_prompt

    demographic_str = ", ".join(descriptions) + " person"
    prompt = base_prompt.replace(trigger_word, f"{trigger_word}, {demographic_str}", 1)

    age = facial_attributes.get('age')
    quality = facial_attributes.get('quality')
    expression = facial_attributes.get('expression')
    # Guard against None so logging can never break prompt construction,
    # and so an age of 0 is reported rather than shown as N/A.
    age_str = age if age is not None else 'N/A'
    quality_str = f"{quality:.2f}" if quality is not None else 'N/A'
    print(f"[FACE] Detected: {', '.join(descriptions)}")
    print(f" Age: {age_str}, Quality: {quality_str}")
    if expression:
        print(f" Expression: {expression}")
    return prompt
def get_demographic_description(age, gender_code):
    """Return demographic labels for an age and a gender code (legacy helper).

    Kept for backward compatibility.

    Args:
        age: Age in years (anything ``int()`` accepts), or None.
        gender_code: 1 -> "male", 0 -> "female". Other values, unconvertible
            values or None add no label.

    Returns:
        list[str]: the first matching AGE_BRACKETS label (if any) followed by
        the gender label (if any). Inputs that cannot be converted to int are
        skipped rather than raising.
    """
    demo_desc = []
    convert_errors = (TypeError, ValueError, OverflowError)

    if age is not None:
        try:
            age_int = int(age)
        except convert_errors:
            age_int = None
        if age_int is not None:
            age_label = next(
                (label for min_age, max_age, label in AGE_BRACKETS if min_age <= age_int < max_age),
                None,
            )
            if age_label is not None:
                demo_desc.append(age_label)

    if gender_code is not None:
        try:
            gender_int = int(gender_code)
        except convert_errors:
            gender_int = None
        gender_label = {1: "male", 0: "female"}.get(gender_int)
        if gender_label is not None:
            demo_desc.append(gender_label)

    return demo_desc
def color_match_lab(target, source, preserve_saturation=True):
    """Transfer colour statistics from ``source`` to ``target`` in LAB space.

    For each LAB channel, the target is standardised (minus its mean, divided
    by its std), rescaled to the source's std and mean, and then linearly
    blended with the original target channel:

    * L (lightness): the std ratio is damped by 0.5. The blend weight is
      ``COLOR_MATCH_CONFIG['lab_lightness_blend']``.
    * a/b (colour): the blend weight is ``'lab_color_blend_preserved'`` when
      ``preserve_saturation`` is true, else ``'lab_color_blend_full'``.

    Channels whose target std is <= 1e-6 (flat) are left unchanged.

    Args:
        target: RGB image array; it is cast to uint8 before conversion.
        source: RGB image array supplying the statistics.
        preserve_saturation: Select the "preserved" a/b blend weight.

    Returns:
        uint8 RGB array. On any failure, the target cast to uint8 is returned.
    """
    try:
        target_lab = cv2.cvtColor(target.astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
        source_lab = cv2.cvtColor(source.astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
        result_lab = np.copy(target_lab)

        color_blend_key = 'lab_color_blend_preserved' if preserve_saturation else 'lab_color_blend_full'
        # (channel index, std-ratio damping, config key of the blend weight)
        channel_plan = (
            (0, 0.5, 'lab_lightness_blend'),
            (1, 1.0, color_blend_key),
            (2, 1.0, color_blend_key),
        )
        for channel, damping, blend_key in channel_plan:
            t = target_lab[:, :, channel]
            s = source_lab[:, :, channel]
            t_mean, t_std = t.mean(), t.std()
            if t_std <= 1e-6:
                continue  # flat channel: nothing to rescale
            matched = (t - t_mean) * (s.std() / t_std) * damping + s.mean()
            blend = COLOR_MATCH_CONFIG[blend_key]
            result_lab[:, :, channel] = t * (1 - blend) + matched * blend

        # Clip before the uint8 cast: out-of-range floats would otherwise wrap
        # (casting them to uint8 is not well-defined) and produce speckles.
        result_lab = np.clip(result_lab, 0, 255)
        return cv2.cvtColor(result_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)
    except Exception as exc:  # best-effort: fall back to the unmodified image
        print(f"[WARN] LAB color match failed: {exc}")
        return target.astype(np.uint8)
def enhanced_color_match(target_img, source_img, face_bbox=None, preserve_vibrance=False):
    """Colour-match ``target_img`` to ``source_img``, optionally only on the face.

    With ``face_bbox``:
    1. The bbox region of the target is LAB colour-matched against the same
       region of the source.
    2. The matched region is blended back through a Gaussian-feathered mask,
       so there is no hard seam.
    3. Pixels outside the box keep their original values.

    Without a bbox, or when the bbox is empty after clipping to the image,
    the whole image is matched.

    Args:
        target_img: Image to adjust (anything ``np.array`` accepts, e.g. a PIL
            image).
        source_img: Colour reference, same kind as ``target_img``.
        face_bbox: Optional ``(x1, y1, x2, y2)`` in target pixel coordinates.
        preserve_vibrance: Currently unused; kept for API compatibility.

    Returns:
        PIL image with the colour-matched result, or ``target_img`` unchanged
        if anything fails.
    """
    try:
        target = np.array(target_img).astype(np.float32)
        source = np.array(source_img).astype(np.float32)

        if face_bbox is None:
            result = color_match_lab(target, source, preserve_saturation=True)
        else:
            height, width = target.shape[:2]
            x1, y1, x2, y2 = [int(c) for c in face_bbox]
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(width, x2), min(height, y2)
            if x2 > x1 and y2 > y1:
                face_mask = np.zeros((height, width), dtype=np.float32)
                face_mask[y1:y2, x1:x2] = 1.0
                face_mask = cv2.GaussianBlur(
                    face_mask,
                    COLOR_MATCH_CONFIG['gaussian_blur_kernel'],
                    COLOR_MATCH_CONFIG['gaussian_blur_sigma'],
                )[:, :, np.newaxis]
                matched = target.copy()
                matched[y1:y2, x1:x2] = color_match_lab(
                    target[y1:y2, x1:x2], source[y1:y2, x1:x2], preserve_saturation=True
                )
                # Feathered blend: matched colours where the mask is 1, the
                # original target where it is 0, smooth in between.
                result = matched * face_mask + target * (1.0 - face_mask)
            else:
                result = color_match_lab(target, source, preserve_saturation=True)

        return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
    except Exception as exc:  # best-effort: keep the input on failure
        print(f"[WARN] Enhanced color match failed: {exc}")
        return target_img
def color_match(target_img, source_img, mode='mkl'):
    """Legacy colour matching of ``target_img`` to ``source_img``.

    Args:
        target_img: Image to adjust (anything ``np.array`` accepts, e.g. a PIL
            image). At least 3 channels are expected.
        source_img: Colour reference, same kind as ``target_img``.
        mode: ``'mkl'`` delegates to ``color_match_lab`` (LAB statistics
            matching, despite the name). Any other value does a per-channel
            mean/std transfer on the first three channels, clipped to [0, 255].

    Returns:
        PIL image with the result, or ``target_img`` unchanged on failure.
    """
    try:
        target = np.array(target_img).astype(np.float32)
        source = np.array(source_img).astype(np.float32)
        if mode == 'mkl':
            result = color_match_lab(target, source)
        else:
            # Start from a copy so channels beyond the first three (e.g. alpha)
            # survive; a zero-filled buffer made RGBA output fully transparent.
            result = target.copy()
            for channel in range(3):
                t = target[:, :, channel]
                s = source[:, :, channel]
                scaled = (t - t.mean()) * (s.std() / (t.std() + 1e-6)) + s.mean()
                result[:, :, channel] = np.clip(scaled, 0, 255)
        return Image.fromarray(result.astype(np.uint8))
    except Exception as exc:  # best-effort: keep the input on failure
        print(f"[WARN] Color match failed: {exc}")
        return target_img
def create_face_mask(image, face_bbox, feather=None):
    """Build a soft elliptical mask over a padded face box.

    Args:
        image: PIL image whose size (and width/height bounds) the mask uses.
        face_bbox: ``(x1, y1, x2, y2)`` face box.
        feather: Gaussian blur radius; defaults to ``FACE_MASK_CONFIG['feather']``.

    Returns:
        Single-channel ('L') PIL image: 255 inside the ellipse, 0 outside,
        softened by the Gaussian blur.
    """
    blur_radius = FACE_MASK_CONFIG['feather'] if feather is None else feather

    left, top, right, bottom = face_bbox
    # The pad is a fraction of the box *width*, applied on all four sides,
    # then the result is clamped to the image bounds.
    pad = int((right - left) * FACE_MASK_CONFIG['padding'])
    ellipse_box = [
        max(0, left - pad),
        max(0, top - pad),
        min(image.width, right + pad),
        min(image.height, bottom + pad),
    ]

    mask = Image.new('L', image.size, 0)
    ImageDraw.Draw(mask).ellipse(ellipse_box, fill=255)
    return mask.filter(ImageFilter.GaussianBlur(blur_radius))
def draw_kps(image_pil, kps, color_list=None):
    """Render facial keypoints as a coloured stick figure on a black canvas.

    Drawing happens in two passes:
    1. Limbs: keypoints 0, 1, 3 and 4 are each joined to keypoint 2 by a filled
       ellipse of half-thickness 4. Each limb uses the colour of its first
       endpoint, and the limb layer is then dimmed to 60%.
    2. Keypoints: every keypoint is drawn on top as a filled circle of
       radius 10.

    Args:
        image_pil: PIL image; only its size is used.
        kps: Sequence of at least five ``(x, y)`` keypoints. Every keypoint
            needs a matching colour in ``color_list``.
        color_list: Per-keypoint RGB colours. Defaults to red, green, blue,
            yellow, magenta.

    Returns:
        PIL image (3 channels, same size as ``image_pil``) containing only the
        drawing.
    """
    if color_list is None:
        # Built per call to avoid a shared mutable default argument.
        color_list = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
    stickwidth = 4
    limb_seq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)

    w, h = image_pil.size
    out_img = np.zeros([h, w, 3])

    for index in limb_seq:
        color = color_list[index[0]]
        x = kps[index][:, 0]
        y = kps[index][:, 1]
        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
        polygon = cv2.ellipse2Poly(
            (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
        )
        # out_img is a private, contiguous buffer, so drawing into it directly
        # is safe; no per-limb copy is needed.
        out_img = cv2.fillConvexPoly(out_img, polygon, color)
    out_img = (out_img * 0.6).astype(np.uint8)

    for idx_kp, kp in enumerate(kps):
        color = color_list[idx_kp]
        x, y = kp
        out_img = cv2.circle(out_img, (int(x), int(y)), 10, color, -1)

    return Image.fromarray(out_img.astype(np.uint8))
def calculate_optimal_size(original_width, original_height, recommended_sizes):
    """Pick the recommended size whose aspect ratio is closest to the original's.

    Args:
        original_width: Width of the source image.
        original_height: Height of the source image (must be non-zero).
        recommended_sizes: Iterable of ``(width, height)`` candidates.

    Returns:
        ``(width, height)`` of the closest-aspect candidate, each rounded down
        to a multiple of 8. Ties go to the earliest candidate.

    Raises:
        ValueError: if ``recommended_sizes`` is empty.
    """
    aspect_ratio = original_width / original_height
    candidates = list(recommended_sizes)
    if not candidates:
        raise ValueError("recommended_sizes must contain at least one (width, height) pair")

    # min() returns the first minimal element, matching the strict "<" scan
    # this replaces.
    width, height = min(candidates, key=lambda size: abs(size[0] / size[1] - aspect_ratio))
    # Snap both dimensions down to multiples of 8.
    return int((width // 8) * 8), int((height // 8) * 8)
def enhance_face_crop(face_crop):
    """Resize a face crop to 224x224 and give it a light multi-stage touch-up.

    Args:
        face_crop: PIL image of a face.

    Returns:
        New 224x224 PIL image. It is LANCZOS-resized, then adjusted in this
        order: sharpness x1.5, contrast x1.1, brightness x1.05.
    """
    image = face_crop.resize((224, 224), Image.LANCZOS)
    # Order matters: each enhancer operates on the previous stage's output.
    for enhancer_cls, factor in (
        (ImageEnhance.Sharpness, 1.5),
        (ImageEnhance.Contrast, 1.1),
        (ImageEnhance.Brightness, 1.05),
    ):
        image = enhancer_cls(image).enhance(factor)
    return image
# Import-time side effect: announce on stdout that this module finished loading.
print("[OK] Utils loaded (Enhanced facial attributes)")
|