|
|
from typing import Mapping
|
|
|
import os
|
|
|
import base64
|
|
|
from tqdm import tqdm
|
|
|
from easydict import EasyDict as edict
|
|
|
import matplotlib.pyplot as plt
|
|
|
import torch
|
|
|
from torch.optim.lr_scheduler import LambdaLR
|
|
|
import pydiffvg
|
|
|
import save_svg
|
|
|
from losses import SDSLoss, ToneLoss, ConformalLoss
|
|
|
from config import set_config
|
|
|
from utils import (
|
|
|
check_and_create_dir,
|
|
|
get_data_augs,
|
|
|
save_image,
|
|
|
preprocess,
|
|
|
learning_rate_decay,
|
|
|
combine_word,
|
|
|
create_video,
|
|
|
)
|
|
|
import wandb
|
|
|
import warnings
|
|
|
|
|
|
# Gamma value forwarded to save_image / pydiffvg.imwrite whenever a PNG is
# written by generate_word_image.
gamma = 1.0

# Install a catch-all "ignore" filter: every Python warning is suppressed for
# the rest of the process.
warnings.simplefilter("ignore")

# Turn off pydiffvg's timing printouts (presumably emitted per render call --
# confirm against the pydiffvg docs).
pydiffvg.set_print_timing(False)
|
|
|
|
|
|
|
|
|
def init_shapes(svg_path: str, trainable: Mapping[str, bool]):
    """Load the initial SVG, mark trainable points, return shapes & params.

    Args:
        svg_path: Path of the source SVG *without* its extension; ``.svg`` is
            appended here before loading.
        trainable: Flags selecting what gets optimised. Only ``point`` is
            consulted. Plain mappings (``{"point": True}``) are read by key;
            any other object is read through its ``point`` attribute.

    Returns:
        A ``(shapes, shape_groups, parameters)`` tuple. ``shapes`` and
        ``shape_groups`` are the last two values returned by
        ``pydiffvg.svg_to_scene``. ``parameters`` is an ``EasyDict``; its
        ``point`` entry (a list holding every path's ``points`` tensor, each
        switched to ``requires_grad = True``) exists only when the ``point``
        flag is truthy.

    Raises:
        KeyError: If ``trainable`` is a mapping without a ``"point"`` key.
    """
    svg = f"{svg_path}.svg"
    _, _, shapes_init, shape_groups_init = pydiffvg.svg_to_scene(svg)

    # The signature advertises a Mapping, so a plain dict must work too.
    # EasyDict is a dict subclass, hence key access covers it as well; objects
    # that are not mappings keep the historical attribute lookup.
    if isinstance(trainable, Mapping):
        optimise_points = trainable["point"]
    else:
        optimise_points = trainable.point

    parameters = edict()
    if optimise_points:
        parameters.point = []
        for path in shapes_init:
            # Control points become leaf tensors the optimiser can update.
            path.points.requires_grad = True
            parameters.point.append(path.points)

    return shapes_init, shape_groups_init, parameters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _render_on_white(render, width, height, shapes, shape_groups, seed):
    """Rasterise the scene and alpha-composite it over a white background.

    The renderer output is indexed as ``(H, W, 4)`` with channel 3 used as
    alpha; the returned tensor keeps only the three colour channels.
    """
    scene_args = pydiffvg.RenderFunction.serialize_scene(width, height, shapes, shape_groups)
    # 2 x 2 are the per-pixel sample counts and ``None`` the background image
    # (per pydiffvg's RenderFunction signature -- confirm against its docs).
    rgba = render(width, height, 2, 2, seed, None, *scene_args)
    alpha = rgba[:, :, 3:4]
    return alpha * rgba[:, :, :3] + torch.ones_like(rgba[:, :, :3]) * (1 - alpha)


def _log_wandb_image(key, image, step):
    """Log ``image`` to wandb under ``key`` through a throw-away pyplot figure."""
    plt.imshow(image.detach().cpu())
    wandb.log({key: wandb.Image(plt)}, step=step)
    plt.close()


def generate_word_image(cfg, device: torch.device):
    """Optimise a single word and return path to the resulting PNG.

    Args:
        cfg: Run configuration, either a ``dict`` (wrapped in ``EasyDict``) or
            an attribute-style object. Fields read here include ``font``,
            ``word``, ``optimized_letter``, ``level_of_cc``, ``render_size``,
            ``cut_size``, ``target``, ``trainable``, ``num_iter``,
            ``batch_size``, ``lr_base``, ``lr``, ``loss``, ``save``,
            ``use_wandb`` and ``experiment_dir``.
        device: Torch device; pydiffvg GPU use is enabled iff its type is
            ``"cuda"``.

    Returns:
        Absolute path of ``<experiment_dir>/output-png/output.png`` when
        ``cfg.save.image`` is truthy, otherwise an empty string (no PNG is
        written in that case).

    Raises:
        ValueError: If ``cfg.num_iter > 0`` and none of the SDS, tone or
            conformal losses is enabled, since there would be nothing to
            optimise.
    """
    if isinstance(cfg, dict):
        cfg = edict(cfg)

    pydiffvg.set_use_gpu(device.type == "cuda")

    print("preprocessing")
    preprocess(cfg.font, cfg.word, cfg.optimized_letter, cfg.level_of_cc)

    # Every optional loss is bound to None when disabled so the training loop
    # can test for it instead of hitting an unbound name.
    sds_loss = SDSLoss(cfg, device) if cfg.loss.use_sds_loss else None

    h = w = cfg.render_size
    data_augs = get_data_augs(cfg.cut_size)
    render = pydiffvg.RenderFunction.apply

    print("initializing shape")
    shapes, shape_groups, parameters = init_shapes(svg_path=cfg.target, trainable=cfg.trainable)

    img_init = _render_on_white(render, w, h, shapes, shape_groups, seed=0)

    if cfg.use_wandb:
        _log_wandb_image("init", img_init, step=0)

    tone_loss = None
    if cfg.loss.tone.use_tone_loss:
        tone_loss = ToneLoss(cfg)
        tone_loss.set_image_init(img_init)

    if cfg.save.init:
        print("saving init")
        filename = os.path.join(cfg.experiment_dir, "svg-init", "init.svg")
        check_and_create_dir(filename)
        save_svg.save_svg(filename, w, h, shapes, shape_groups)

    num_iter = cfg.num_iter
    # NOTE: init_shapes only creates parameters["point"] when
    # cfg.trainable.point is truthy; otherwise this lookup raises KeyError.
    optim = torch.optim.Adam(
        [{"params": parameters["point"], "lr": cfg.lr_base["point"]}],
        betas=(0.9, 0.9),
        eps=1e-6,
    )

    conformal_loss = None
    if cfg.loss.conformal.use_conformal_loss:
        conformal_loss = ConformalLoss(parameters, device, cfg.optimized_letter, shape_groups)

    if num_iter > 0 and sds_loss is None and tone_loss is None and conformal_loss is None:
        raise ValueError(
            "no loss is enabled: set at least one of cfg.loss.use_sds_loss, "
            "cfg.loss.tone.use_tone_loss or cfg.loss.conformal.use_conformal_loss"
        )

    def lr_lambda(it):
        # LambdaLR expects a multiplier of the base lr, hence the division by
        # lr_init.
        return learning_rate_decay(
            it,
            cfg.lr.lr_init,
            cfg.lr.lr_final,
            num_iter,
            lr_delay_steps=cfg.lr.lr_delay_steps,
            lr_delay_mult=cfg.lr.lr_delay_mult,
        ) / cfg.lr.lr_init

    scheduler = LambdaLR(optim, lr_lambda=lr_lambda, last_epoch=-1)

    # Fallback so the final PNG can still be written when num_iter == 0.
    img = img_init

    print("start training")
    for step in tqdm(range(num_iter)):
        if cfg.use_wandb:
            wandb.log({"learning_rate": optim.param_groups[0]["lr"]}, step=step)
        optim.zero_grad()

        # The step index doubles as the renderer seed.
        img = _render_on_white(render, w, h, shapes, shape_groups, seed=step)

        is_last_step = step == num_iter - 1
        if cfg.save.video and (step % cfg.save.video_frame_freq == 0 or is_last_step):
            save_image(img, os.path.join(cfg.experiment_dir, "video-png", f"iter{step:04d}.png"), gamma)
            svg_frame = os.path.join(cfg.experiment_dir, "video-svg", f"iter{step:04d}.svg")
            check_and_create_dir(svg_frame)
            save_svg.save_svg(svg_frame, w, h, shapes, shape_groups)
            if cfg.use_wandb:
                _log_wandb_image("img", img, step=step)

        # (H, W, 3) -> (batch_size, 3, H, W): the same render repeated.
        x = img.unsqueeze(0).permute(0, 3, 1, 2).repeat(cfg.batch_size, 1, 1, 1)

        # Terms are summed in the original order: SDS, tone, conformal.
        loss_terms = []
        if sds_loss is not None:
            loss_terms.append(sds_loss(data_augs.forward(x)))
        if tone_loss is not None:
            loss_terms.append(tone_loss(x, step))
        if conformal_loss is not None:
            loss_terms.append(cfg.loss.conformal.angeles_w * conformal_loss())
        loss = sum(loss_terms[1:], loss_terms[0])

        if cfg.use_wandb:
            wandb.log({"total_loss": loss.item()}, step=step)

        loss.backward()
        optim.step()
        scheduler.step()

    svg_out = os.path.join(cfg.experiment_dir, "output-svg", "output.svg")
    check_and_create_dir(svg_out)
    save_svg.save_svg(svg_out, w, h, shapes, shape_groups)

    combine_word(cfg.word, cfg.optimized_letter, cfg.font, cfg.experiment_dir)

    png_out = ""
    if cfg.save.image:
        png_out = os.path.join(cfg.experiment_dir, "output-png", "output.png")
        check_and_create_dir(png_out)
        pydiffvg.imwrite(img.detach().cpu(), png_out, gamma=gamma)
        if cfg.use_wandb:
            _log_wandb_image("img", img, step=num_iter)

    if cfg.save.video:
        print("saving video")
        create_video(cfg.num_iter, cfg.experiment_dir, cfg.save.video_frame_freq)

    if cfg.use_wandb:
        # No wandb.init happens in this function; the run is presumably
        # started elsewhere (e.g. by the config setup) -- confirm.
        wandb.finish()

    # os.path.abspath("") would return the current working directory, which is
    # not a PNG path, so the "nothing saved" case stays an empty string.
    return os.path.abspath(png_out) if png_out else ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cli_entry():
    """Command-line entry point: build the config, pick a device, run once.

    The configuration comes from ``set_config()``; CUDA is used when
    ``torch.cuda.is_available()`` reports it, otherwise the CPU. The path
    returned by ``generate_word_image`` is discarded.
    """
    run_cfg = set_config()
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")
    generate_word_image(run_cfg, target_device)


# Run the optimisation only when executed as a script, not on import.
if __name__ == "__main__":
    cli_entry()
|
|
|
|