import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.nn import functional as nnf

from diffusers import StableDiffusionPipeline
from easydict import EasyDict
from scipy.spatial import Delaunay
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon


class SDSLoss(nn.Module):
    def __init__(self, cfg, device):
        super(SDSLoss, self).__init__()
        self.cfg = cfg
        self.device = device
        # Run the diffusion backbone in fp16 on CUDA; fall back to fp32 on CPU.
        self.fp16 = device.type == "cuda"
        dtype = torch.float16 if self.fp16 else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(
            cfg.diffusion.model,
            torch_dtype=dtype,
            token=cfg.token,
        ).to(device)

        if self.fp16:
            # Trade compute for memory so the fp16 UNet fits alongside the rasterizer.
            self.pipe.enable_attention_slicing(slice_size=1)
            self.pipe.enable_vae_slicing()
            self.pipe.enable_vae_tiling()
            self.pipe.unet.enable_gradient_checkpointing()

        # Noise-schedule terms used to weight the SDS gradient.
        self.alphas = self.pipe.scheduler.alphas_cumprod.to(device)
        self.sigmas = torch.sqrt(1 - self.alphas)

        self.text_embeddings = None
        self.embed_text()

    def embed_text(self):
        # Classifier-free guidance needs embeddings for both the empty prompt
        # and the caption.
        tok = self.pipe.tokenizer
        txt = tok(self.cfg.caption, padding="max_length",
                  max_length=tok.model_max_length, truncation=True,
                  return_tensors="pt")
        un = tok([""], padding="max_length",
                 max_length=tok.model_max_length, return_tensors="pt")

        with torch.no_grad():
            te = self.pipe.text_encoder.eval()
            # The token ids must live on the same device as the text encoder.
            em_txt = te(txt.input_ids.to(self.device)).last_hidden_state.to(torch.float32)
            em_un = te(un.input_ids.to(self.device)).last_hidden_state.to(torch.float32)

        self.text_embeddings = (
            torch.cat([em_un, em_txt])
            .repeat_interleave(self.cfg.batch_size, 0)
            .to(self.device)
        )

    def forward(self, x_aug: torch.Tensor) -> torch.Tensor:
        # Map the rasterized image from [0, 1] to [-1, 1], then encode it into
        # the VAE latent space.
        x = (x_aug * 2.0 - 1.0).to(self.device, dtype=torch.float32)
        if self.fp16:
            with torch.cuda.amp.autocast():
                latents = self.pipe.vae.encode(x).latent_dist.sample()
        else:
            latents = self.pipe.vae.encode(x).latent_dist.sample()

        latents = 0.18215 * latents  # Stable Diffusion v1 latent scaling factor
        if self.fp16:
            torch.cuda.empty_cache()

        # Sample a timestep away from the extremes of the schedule and noise
        # the latents accordingly.
        t = torch.randint(
            50,
            min(950, self.cfg.diffusion.timesteps) - 1,
            (latents.size(0),),
            device=self.device,
        )
        eps = torch.randn_like(latents)
        z_t = self.pipe.scheduler.add_noise(latents, eps, t)

        # Predict noise for the unconditional and conditional prompts in two
        # passes (instead of one batched pass) to halve peak UNet memory.
        # The UNet runs under no_grad: SDS uses its prediction only as a
        # gradient weight and never backpropagates through the UNet itself.
        emb_u, emb_c = self.text_embeddings.chunk(2)
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=self.fp16):
                eps_u = self.pipe.unet(z_t, t, encoder_hidden_states=emb_u).sample
            if self.fp16:
                torch.cuda.empty_cache()
            with torch.cuda.amp.autocast(enabled=self.fp16):
                eps_c = self.pipe.unet(z_t, t, encoder_hidden_states=emb_c).sample

        # Classifier-free guidance.
        eps_t = eps_u + self.cfg.diffusion.guidance_scale * (eps_c - eps_u)

        # SDS gradient w(t) * (eps_t - eps); reshape the per-sample schedule
        # terms so they broadcast over the channel and spatial dimensions.
        alpha_t = self.alphas[t].view(-1, 1, 1, 1)
        sigma_t = self.sigmas[t].view(-1, 1, 1, 1)
        grad = (alpha_t**0.5 * sigma_t * (eps_t - eps)).nan_to_num_()
        # Surrogate loss whose gradient w.r.t. the latents is exactly `grad`.
        return (grad * latents).sum(1).mean()
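

# A minimal, self-contained illustration (not part of the training pipeline)
# of the surrogate-loss trick at the end of SDSLoss.forward: since the UNet
# prediction carries no autograd graph, d/d(latents) of (grad * latents).sum()
# is exactly `grad`.
def _sds_surrogate_demo():
    latents = torch.randn(1, 4, 8, 8, requires_grad=True)
    grad = torch.randn_like(latents)  # stands in for w(t) * (eps_t - eps)
    (grad * latents).sum().backward()
    assert torch.allclose(latents.grad, grad)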


class ToneLoss(nn.Module):
    def __init__(self, cfg):
        super(ToneLoss, self).__init__()
        self.dist_loss_weight = cfg.loss.tone.dist_loss_weight
        self.im_init = None
        self.cfg = cfg
        self.mse_loss = nn.MSELoss()
        self.blurrer = torchvision.transforms.GaussianBlur(
            kernel_size=(cfg.loss.tone.pixel_dist_kernel_blur,
                         cfg.loss.tone.pixel_dist_kernel_blur),
            sigma=cfg.loss.tone.pixel_dist_sigma)

    def set_image_init(self, im_init):
        # HWC -> NCHW so the blur transform can consume the raster.
        self.im_init = im_init.permute(2, 0, 1).unsqueeze(0)
        self.init_blurred = self.blurrer(self.im_init)

    def get_scheduler(self, step=None):
        # Gaussian weight schedule peaking at step 300 (width ~20 steps).
        if step is not None:
            return self.dist_loss_weight * np.exp(-(1 / 5) * ((step - 300) / 20) ** 2)
        else:
            return self.dist_loss_weight

    def forward(self, cur_raster, step=None):
        blurred_cur = self.blurrer(cur_raster)
        return self.mse_loss(self.init_blurred.detach(), blurred_cur) * self.get_scheduler(step)
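
# Minimal ToneLoss smoke test. The config values here are hypothetical
# placeholders (the real ones live in the project config), chosen only so the
# demo runs standalone:
def _tone_loss_demo():
    cfg = EasyDict({"loss": {"tone": {"dist_loss_weight": 100.0,
                                      "pixel_dist_kernel_blur": 201,
                                      "pixel_dist_sigma": 30.0}}})
    tone = ToneLoss(cfg)
    tone.set_image_init(torch.rand(600, 600, 3))       # H x W x C raster
    print(tone(torch.rand(1, 3, 600, 600), step=300))  # weight is maximal at step 300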


class ConformalLoss:
    def __init__(self, parameters: EasyDict, device: torch.device, target_letter: str, shape_groups):
        self.parameters = parameters
        self.target_letter = target_letter
        self.shape_groups = shape_groups
        self.faces = self.init_faces(device)
        # Rolled copies of each face's index triple, used to form directed edges.
        self.faces_roll_a = [torch.roll(self.faces[i], 1, 1) for i in range(len(self.faces))]

        with torch.no_grad():
            self.angles = []
            self.reset()

    def get_angles(self, points: torch.Tensor) -> list:
        # Per-triangle angles between consecutive directed edges of each face.
        angles_ = []
        for i in range(len(self.faces)):
            triangles = points[self.faces[i]]
            triangles_roll_a = points[self.faces_roll_a[i]]
            edges = triangles_roll_a - triangles
            length = edges.norm(dim=-1)
            edges = edges / (length + 1e-1)[:, :, None]
            edges_roll = torch.roll(edges, 1, 1)
            cosine = torch.einsum('ned,ned->ne', edges, edges_roll)
            angles = torch.arccos(cosine)
            angles_.append(angles)
        return angles_

    def get_letter_inds(self, letter_to_insert):
        for group, l in zip(self.shape_groups, self.target_letter):
            if l == letter_to_insert:
                letter_inds = group.shape_ids
                return letter_inds[0], letter_inds[-1], len(letter_inds)

    def reset(self):
        points = torch.cat([point.clone().detach() for point in self.parameters.point]).to(self.faces[0].device)
        self.angles = self.get_angles(points)

    def init_faces(self, device: torch.device) -> list:
        # Triangulate each letter and keep only the triangles that lie inside
        # the glyph outline.
        faces_ = []
        for j, c in enumerate(self.target_letter):
            points_np = [self.parameters.point[i].clone().detach().cpu().numpy()
                         for i in range(len(self.parameters.point))]
            start_ind, end_ind, shapes_per_letter = self.get_letter_inds(c)
            print(c, start_ind, end_ind)
            holes = []
            if shapes_per_letter > 1:
                holes = points_np[start_ind + 1:end_ind]
            poly = Polygon(points_np[start_ind], holes=holes)
            poly = poly.buffer(0)  # repair self-intersections in the outline
            points_np = np.concatenate(points_np)
            faces = Delaunay(points_np).simplices
            is_intersect = np.array([poly.contains(Point(points_np[face].mean(0))) for face in faces], dtype=bool)
            faces_.append(torch.from_numpy(faces[is_intersect]).to(device, dtype=torch.int64))
        return faces_

    def __call__(self) -> torch.Tensor:
        # Penalize deviation of the current triangle angles from the initial ones.
        loss_angles = 0
        points = torch.cat(self.parameters.point).to(self.faces[0].device)
        angles = self.get_angles(points)
        for i in range(len(self.faces)):
            loss_angles += nnf.mse_loss(angles[i], self.angles[i])
        return loss_angles
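

# Standalone illustration of the angle computation in ConformalLoss.get_angles
# on a single equilateral triangle (hypothetical data; the epsilon used in the
# real normalization is omitted). Note the measured angle is between
# consecutive directed edges, i.e. pi minus the interior angle, so every
# corner of an equilateral triangle reads as 120 degrees.
def _conformal_angles_demo():
    points = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.5, 3 ** 0.5 / 2]])
    faces = torch.tensor([[0, 1, 2]])
    triangles = points[faces]                        # (1, 3, 2)
    edges = torch.roll(triangles, 1, 1) - triangles  # directed edges
    edges = edges / edges.norm(dim=-1)[:, :, None]
    cosine = torch.einsum('ned,ned->ne', edges, torch.roll(edges, 1, 1))
    print(torch.arccos(cosine) * 180 / torch.pi)     # ~[[120., 120., 120.]]


if __name__ == "__main__":
    # Cheap sanity checks; SDSLoss itself needs model weights and a full
    # config, so it is not exercised here.
    _sds_surrogate_demo()
    _tone_loss_demo()
    _conformal_angles_demo()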