# Provenance: Hugging Face blob, commit 5e3465d — "Re-commit with fonts & image under Git LFS".
# (Web-page chrome from the scrape — author avatar line, "raw", "history blame", "6.52 kB" —
#  converted to this comment so the file parses as Python.)
from typing import Mapping
import os
import base64 # still here if you need it later
from tqdm import tqdm
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import LambdaLR
import pydiffvg
import save_svg
from losses import SDSLoss, ToneLoss, ConformalLoss
from config import set_config
from utils import (
check_and_create_dir,
get_data_augs,
save_image,
preprocess,
learning_rate_decay,
combine_word,
create_video,
)
import wandb
import warnings
# Silence all third-party warnings so the tqdm progress bar stays readable.
warnings.filterwarnings("ignore")
# Disable pydiffvg's per-render timing printout.
pydiffvg.set_print_timing(False)
# Gamma applied when writing PNG outputs (1.0 = no gamma correction).
gamma = 1.0
def init_shapes(svg_path: str, trainable: Mapping[str, bool]):
    """Load the initial SVG scene and expose its control points as trainables.

    Reads ``{svg_path}.svg`` via pydiffvg and, when ``trainable.point`` is
    truthy, flags every path's point tensor with ``requires_grad`` and
    collects them under ``parameters.point``.

    Returns:
        (shapes, shape_groups, parameters) — the scene's shapes and groups,
        plus an EasyDict of optimizable tensors.
    """
    _, _, scene_shapes, scene_groups = pydiffvg.svg_to_scene(f"{svg_path}.svg")

    parameters = edict()
    if trainable.point:
        point_tensors = []
        for shape in scene_shapes:
            shape.points.requires_grad = True
            point_tensors.append(shape.points)
        parameters.point = point_tensors

    return scene_shapes, scene_groups, parameters
# -----------------------------------------------------------------------------
# Public entry‑point that CLI *and* FastAPI reuse
# -----------------------------------------------------------------------------
def generate_word_image(cfg, device: torch.device):
    """Optimise a single word's letter geometry and return the output PNG path.

    Shared entry point for both the CLI and FastAPI callers.

    Args:
        cfg: experiment configuration (plain ``dict`` or ``EasyDict``); must
            expose the fields read below (``font``, ``word``, ``loss``,
            ``save``, ``lr`` …). NOTE(review): full schema comes from
            ``config.set_config`` — confirm there for any new field.
        device: torch device to render/optimise on; GPU rendering is enabled
            in pydiffvg iff ``device.type == "cuda"``.

    Returns:
        Absolute path to the saved PNG, or the absolute form of "" (the CWD)
        when ``cfg.save.image`` is false — callers should check ``cfg.save.image``.
    """
    # Accept both plain dicts and EasyDicts so CLI and API callers can share this.
    if isinstance(cfg, dict):
        cfg = edict(cfg)

    pydiffvg.set_use_gpu(device.type == "cuda")

    print("preprocessing")
    preprocess(cfg.font, cfg.word, cfg.optimized_letter, cfg.level_of_cc)

    # BUGFIX: the original built sds_loss only under this flag but used it
    # unconditionally in the training loop, raising NameError when disabled.
    sds_loss = SDSLoss(cfg, device) if cfg.loss.use_sds_loss else None

    h = w = cfg.render_size
    data_augs = get_data_augs(cfg.cut_size)
    render = pydiffvg.RenderFunction.apply

    print("initializing shape")
    shapes, shape_groups, parameters = init_shapes(svg_path=cfg.target, trainable=cfg.trainable)

    # Render the untouched glyph once: composite RGBA over white, keep RGB.
    scene_args = pydiffvg.RenderFunction.serialize_scene(w, h, shapes, shape_groups)
    img_init = render(w, h, 2, 2, 0, None, *scene_args)
    img_init = img_init[:, :, 3:4] * img_init[:, :, :3] \
        + torch.ones_like(img_init[:, :, :3]) * (1 - img_init[:, :, 3:4])
    img_init = img_init[:, :, :3]

    if cfg.use_wandb:
        plt.imshow(img_init.detach().cpu())
        wandb.log({"init": wandb.Image(plt)}, step=0)
        plt.close()

    if cfg.loss.tone.use_tone_loss:
        tone_loss = ToneLoss(cfg)
        tone_loss.set_image_init(img_init)

    if cfg.save.init:
        print("saving init")
        filename = os.path.join(cfg.experiment_dir, "svg-init", "init.svg")
        check_and_create_dir(filename)
        save_svg.save_svg(filename, w, h, shapes, shape_groups)

    num_iter = cfg.num_iter
    optim = torch.optim.Adam(
        [{"params": parameters["point"], "lr": cfg.lr_base["point"]}],
        betas=(0.9, 0.9),
        eps=1e-6,
    )

    if cfg.loss.conformal.use_conformal_loss:
        conformal_loss = ConformalLoss(parameters, device, cfg.optimized_letter, shape_groups)

    # Normalise the decayed LR by lr_init so LambdaLR scales the base lr correctly.
    lr_lambda = lambda step: learning_rate_decay(
        step,
        cfg.lr.lr_init,
        cfg.lr.lr_final,
        num_iter,
        lr_delay_steps=cfg.lr.lr_delay_steps,
        lr_delay_mult=cfg.lr.lr_delay_mult,
    ) / cfg.lr.lr_init
    scheduler = LambdaLR(optim, lr_lambda=lr_lambda, last_epoch=-1)

    print("start training")
    # BUGFIX: seed `img` so the post-loop PNG save is defined even if num_iter == 0.
    img = img_init
    for step in tqdm(range(num_iter)):
        if cfg.use_wandb:
            wandb.log({"learning_rate": optim.param_groups[0]["lr"]}, step=step)
        optim.zero_grad()

        # Re-render the (now perturbed) scene; step is passed as the RNG seed.
        scene_args = pydiffvg.RenderFunction.serialize_scene(w, h, shapes, shape_groups)
        img = render(w, h, 2, 2, step, None, *scene_args)
        img = img[:, :, 3:4] * img[:, :, :3] \
            + torch.ones_like(img[:, :, :3]) * (1 - img[:, :, 3:4])
        img = img[:, :, :3]

        if cfg.save.video and (step % cfg.save.video_frame_freq == 0 or step == num_iter - 1):
            save_image(img, os.path.join(cfg.experiment_dir, "video-png", f"iter{step:04d}.png"), gamma)
            svg_frame = os.path.join(cfg.experiment_dir, "video-svg", f"iter{step:04d}.svg")
            check_and_create_dir(svg_frame)
            save_svg.save_svg(svg_frame, w, h, shapes, shape_groups)

        if cfg.use_wandb:
            plt.imshow(img.detach().cpu())
            wandb.log({"img": wandb.Image(plt)}, step=step)
            plt.close()

        # HWC -> NCHW and replicate across the batch for the augmentation pipeline.
        x = img.unsqueeze(0).permute(0, 3, 1, 2).repeat(cfg.batch_size, 1, 1, 1)
        x_aug = data_augs.forward(x)

        loss = sds_loss(x_aug) if sds_loss is not None else torch.tensor(0.0, device=device)
        if cfg.loss.tone.use_tone_loss:
            loss = loss + tone_loss(x, step)
        if cfg.loss.conformal.use_conformal_loss:
            loss = loss + cfg.loss.conformal.angeles_w * conformal_loss()

        if cfg.use_wandb:
            wandb.log({"total_loss": loss.item()}, step=step)

        loss.backward()
        optim.step()
        scheduler.step()

    # Persist the final SVG and recombine the optimised letter into the word.
    svg_out = os.path.join(cfg.experiment_dir, "output-svg", "output.svg")
    check_and_create_dir(svg_out)
    save_svg.save_svg(svg_out, w, h, shapes, shape_groups)
    combine_word(cfg.word, cfg.optimized_letter, cfg.font, cfg.experiment_dir)

    if cfg.save.image:
        png_out = os.path.join(cfg.experiment_dir, "output-png", "output.png")
        check_and_create_dir(png_out)
        pydiffvg.imwrite(img.detach().cpu(), png_out, gamma=gamma)
        if cfg.use_wandb:
            plt.imshow(img.detach().cpu())
            wandb.log({"img": wandb.Image(plt)}, step=num_iter)
            plt.close()
    else:
        png_out = ""

    if cfg.save.video:
        print("saving video")
        create_video(cfg.num_iter, cfg.experiment_dir, cfg.save.video_frame_freq)

    if cfg.use_wandb:
        wandb.finish()

    return os.path.abspath(png_out)
# -----------------------------------------------------------------------------
# CLI entry‑point – original behaviour when run directly
# -----------------------------------------------------------------------------
def cli_entry():
    """CLI entry point: parse config, pick a device, run the optimisation."""
    config = set_config()
    dev_name = "cuda" if torch.cuda.is_available() else "cpu"
    generate_word_image(config, torch.device(dev_name))


if __name__ == "__main__":
    cli_entry()