import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.nn import functional as nnf

from diffusers import StableDiffusionPipeline
from easydict import EasyDict
from scipy.spatial import Delaunay
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon


class SDSLoss(nn.Module):
    def __init__(self, cfg, device):
        super(SDSLoss, self).__init__()
        self.cfg = cfg
        self.device = device
        self.fp16 = device.type == "cuda"
        dtype = torch.float16 if self.fp16 else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(
            cfg.diffusion.model,
            torch_dtype=dtype,
            token=cfg.token,
            device_map=None,
            safety_checker=None,  # drop the NSFW checker (~750 MB)
            requires_safety_checker=False,
        )
        self.pipe.to(device)
        self.pipe.text_encoder.to(device)
        if self.fp16:
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception as e:
                print(f"WARNING: xFormers memory-efficient attention could not be enabled: {e}")

        # Spend extra VRAM for speed: turn the memory-saving features off.
        if hasattr(self.pipe, "disable_attention_slicing"):
            self.pipe.disable_attention_slicing()
        if hasattr(self.pipe, "disable_vae_slicing"):
            self.pipe.disable_vae_slicing()
        if hasattr(self.pipe, "disable_vae_tiling"):
            self.pipe.disable_vae_tiling()
        # Gradient checkpointing trades speed for memory; keep it off.
        if self.fp16 and hasattr(self.pipe.unet, "disable_gradient_checkpointing"):
            self.pipe.unet.disable_gradient_checkpointing()

        alphas_cumprod = torch.as_tensor(self.pipe.scheduler.alphas_cumprod)
        self.alphas = alphas_cumprod.to(device)    # alpha_bar_t
        self.sigmas = torch.sqrt(1 - self.alphas)  # sqrt(1 - alpha_bar_t)

        # Embed the text prompt while all weights are still real tensors.
        self.embed_text()
        # Optionally enable off-loading afterwards (only UNet & VAE would get
        # meta tensors) and drop the then-unneeded text encoder:
        # self.pipe.enable_model_cpu_offload()
        # del self.pipe.text_encoder, self.pipe.tokenizer

    def embed_text(self):
        tok = self.pipe.tokenizer
        txt = tok(
            self.cfg.caption,
            padding="max_length",
            max_length=tok.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        un = tok(
            [""],
            padding="max_length",
            max_length=tok.model_max_length,
            return_tensors="pt",
        )
        with torch.no_grad():
            te = self.pipe.text_encoder.eval()
            # Make sure both id tensors live on self.device.
            ids_c = txt.input_ids.to(self.device)
            ids_u = un.input_ids.to(self.device)
            em_txt = te(ids_c).last_hidden_state.to(torch.float32)
            em_un = te(ids_u).last_hidden_state.to(torch.float32)
        # [uncond, cond] embeddings, repeated per batch element.
        self.text_embeddings = (
            torch.cat([em_un, em_txt])
            .repeat_interleave(self.cfg.batch_size, 0)
            .to(self.device)
        )

    def forward(self, x_aug: torch.Tensor) -> torch.Tensor:
        # ---------------------------------------------------- encode
        x = (x_aug * 2.0 - 1.0).to(self.device, dtype=torch.float32)  # [0,1] -> [-1,1]
        with torch.autocast(self.device.type, enabled=self.fp16):
            latents = self.pipe.vae.encode(x).latent_dist.sample()
        latents = 0.18215 * latents  # SD v1 VAE scaling factor
        if self.fp16:
            torch.cuda.empty_cache()

        # ---------------------------------------------------- add noise
        t = torch.randint(
            50,
            min(950, self.cfg.diffusion.timesteps) - 1,
            (latents.size(0),),
            device=self.device,
        )
        eps = torch.randn_like(latents)
        z_t = self.pipe.scheduler.add_noise(latents, eps, t)

        # ---------------------------------------------------- sequential CFG
        # Run the two UNet passes without grad: SDS treats the noise
        # prediction as a constant, so no gradient must flow through the UNet.
        emb_u, emb_c = self.text_embeddings.chunk(2)
        with torch.no_grad():
            with torch.autocast(self.device.type, enabled=self.fp16):
                eps_u = self.pipe.unet(z_t, t, encoder_hidden_states=emb_u).sample
            if self.fp16:
                torch.cuda.empty_cache()  # release ~500 MB between the two passes
            with torch.autocast(self.device.type, enabled=self.fp16):
                eps_c = self.pipe.unet(z_t, t, encoder_hidden_states=emb_c).sample
        eps_t = eps_u + self.cfg.diffusion.guidance_scale * (eps_c - eps_u)

        # ---------------------------------------------------- SDS grad & loss
        # Reshape the per-sample weights so they broadcast over (B, C, H, W).
        alpha_t = self.alphas[t].view(-1, 1, 1, 1)
        sigma_t = self.sigmas[t].view(-1, 1, 1, 1)
        grad = (alpha_t**0.5 * sigma_t * (eps_t - eps)).nan_to_num_()
        # d/dz of (grad * z) with grad held constant is grad: the SDS update.
        return (grad * latents).sum(1).mean()
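# ---------------------------------------------------------------------------
# Explanatory note (not part of the original code): SDSLoss.forward implements
# score distillation sampling (DreamFusion).  Because the UNet passes run
# under no_grad, the weighted residual
#     grad = w(t) * (eps_t - eps),   w(t) = sqrt(alpha_bar_t * (1 - alpha_bar_t)),
# is constant w.r.t. the latents, so backpropagating through
# (grad * latents).sum(1).mean() pushes grad (up to the averaging constant)
# into the latents and, through the VAE encoder, into the rendered image.
# A minimal call sequence (cfg fields are the ones read in __init__; values
# are placeholders):
#
#     device = torch.device("cuda")
#     sds = SDSLoss(cfg, device)
#     x_aug = render()              # (B, 3, 512, 512) in [0, 1], requires grad
#     sds(x_aug).backward()
# ---------------------------------------------------------------------------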
class ToneLoss(nn.Module):
    def __init__(self, cfg):
        super(ToneLoss, self).__init__()
        self.dist_loss_weight = cfg.loss.tone.dist_loss_weight
        self.im_init = None
        self.cfg = cfg
        self.mse_loss = nn.MSELoss()
        self.blurrer = torchvision.transforms.GaussianBlur(
            kernel_size=(
                cfg.loss.tone.pixel_dist_kernel_blur,
                cfg.loss.tone.pixel_dist_kernel_blur,
            ),
            sigma=(cfg.loss.tone.pixel_dist_sigma),
        )

    def set_image_init(self, im_init):
        # (H, W, C) -> (1, C, H, W)
        self.im_init = im_init.permute(2, 0, 1).unsqueeze(0)
        self.init_blurred = self.blurrer(self.im_init)

    def get_scheduler(self, step=None):
        if step is not None:
            return self.dist_loss_weight * np.exp(-(1 / 5) * ((step - 300) / 20) ** 2)
        else:
            return self.dist_loss_weight

    def forward(self, cur_raster, step=None):
        blurred_cur = self.blurrer(cur_raster)
        return self.mse_loss(self.init_blurred.detach(), blurred_cur) * self.get_scheduler(step)
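# ---------------------------------------------------------------------------
# Explanatory note (not part of the original code): get_scheduler above is a
# Gaussian bump over optimization steps,
#     weight(step) = dist_loss_weight * exp(-(step - 300)^2 / 2000),
# peaking at step 300 and fading on both sides, so the tone constraint binds
# most strongly in the middle of optimization.  With step=None the constant
# dist_loss_weight is used instead.
# ---------------------------------------------------------------------------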
class ConformalLoss:
    def __init__(self, parameters: EasyDict, device: torch.device, target_letter: str, shape_groups):
        self.parameters = parameters
        self.target_letter = target_letter
        self.shape_groups = shape_groups
        self.faces = self.init_faces(device)
        self.faces_roll_a = [torch.roll(self.faces[i], 1, 1) for i in range(len(self.faces))]
        with torch.no_grad():
            self.angles = []
            self.reset()

    def get_angles(self, points: torch.Tensor) -> list:
        angles_ = []
        for i in range(len(self.faces)):
            triangles = points[self.faces[i]]
            triangles_roll_a = points[self.faces_roll_a[i]]
            edges = triangles_roll_a - triangles
            length = edges.norm(dim=-1)
            # Normalize edge vectors; the offset guards against zero-length edges.
            edges = edges / (length + 1e-1)[:, :, None]
            edges_roll = torch.roll(edges, 1, 1)
            cosine = torch.einsum('ned,ned->ne', edges, edges_roll)
            angles = torch.arccos(cosine)
            angles_.append(angles)
        return angles_

    def get_letter_inds(self, letter_to_insert):
        for group, l in zip(self.shape_groups, self.target_letter):
            if l == letter_to_insert:
                letter_inds = group.shape_ids
                return letter_inds[0], letter_inds[-1], len(letter_inds)

    def reset(self):
        points = torch.cat([point.clone().detach() for point in self.parameters.point]).to(self.faces[0].device)
        self.angles = self.get_angles(points)

    def init_faces(self, device: torch.device) -> list:
        faces_ = []
        for j, c in enumerate(self.target_letter):
            points_np = [self.parameters.point[i].clone().detach().cpu().numpy() for i in range(len(self.parameters.point))]
            start_ind, end_ind, shapes_per_letter = self.get_letter_inds(c)
            print(c, start_ind, end_ind)
            holes = []
            if shapes_per_letter > 1:
                holes = points_np[start_ind + 1:end_ind]
            poly = Polygon(points_np[start_ind], holes=holes)
            poly = poly.buffer(0)
            points_np = np.concatenate(points_np)
            faces = Delaunay(points_np).simplices
            # Keep only the triangles whose centroid lies inside the letter.
            is_intersect = np.array([poly.contains(Point(points_np[face].mean(0))) for face in faces], dtype=bool)
            faces_.append(torch.from_numpy(faces[is_intersect]).to(device, dtype=torch.int64))
        return faces_

    def __call__(self) -> torch.Tensor:
        loss_angles = 0
        points = torch.cat(self.parameters.point).to(self.faces[0].device)
        angles = self.get_angles(points)
        for i in range(len(self.faces)):
            loss_angles += nnf.mse_loss(angles[i], self.angles[i])
        return loss_angles
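# ---------------------------------------------------------------------------
# Minimal smoke test (hypothetical; not part of the original module).  Of the
# three losses, only ToneLoss needs neither pretrained weights nor vector
# shapes, so it is the one exercised here.  The cfg values are placeholders
# mirroring the fields read in ToneLoss.__init__.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = EasyDict({
        "loss": {
            "tone": {
                "dist_loss_weight": 100.0,     # placeholder weight
                "pixel_dist_kernel_blur": 51,  # placeholder (odd) kernel size
                "pixel_dist_sigma": 30.0,      # placeholder blur sigma
            }
        }
    })
    tone = ToneLoss(cfg)
    tone.set_image_init(torch.rand(224, 224, 3))          # (H, W, C) image in [0, 1]
    cur = torch.rand(1, 3, 224, 224, requires_grad=True)  # current raster, (1, C, H, W)
    loss = tone(cur, step=300)                            # scheduler peaks at step 300
    loss.backward()
    print("tone loss:", loss.item(), "grad norm:", cur.grad.norm().item())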