import os
import base64  # still here if you need it later
import warnings
from typing import Mapping

from tqdm import tqdm
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import LambdaLR
import pydiffvg
import wandb

import save_svg
from losses import SDSLoss, ToneLoss, ConformalLoss
from config import set_config
from utils import (
    check_and_create_dir,
    get_data_augs,
    save_image,
    preprocess,
    learning_rate_decay,
    combine_word,
    create_video,
)

warnings.filterwarnings("ignore")
pydiffvg.set_print_timing(False)

# Gamma passed to the PNG writers (`save_image` / `pydiffvg.imwrite`).
gamma = 1.0


def init_shapes(svg_path: str, trainable: Mapping[str, bool]):
    """Load ``<svg_path>.svg`` and mark its path points as trainable.

    Args:
        svg_path: Path of the SVG file *without* the ``.svg`` extension.
        trainable: Flags selecting which parameter groups to optimise. Only the
            ``"point"`` entry is read; item access is used so both plain dicts
            and ``EasyDict`` objects work.

    Returns:
        ``(shapes, shape_groups, parameters)``. When ``trainable["point"]`` is
        truthy, ``parameters.point`` is the list of every shape's ``points``
        tensor, each switched to ``requires_grad=True``. Otherwise
        ``parameters`` is an empty ``EasyDict``.
    """
    svg = f"{svg_path}.svg"
    _, _, shapes_init, shape_groups_init = pydiffvg.svg_to_scene(svg)
    parameters = edict()
    if trainable["point"]:
        parameters.point = []
        for path in shapes_init:
            path.points.requires_grad = True
            parameters.point.append(path.points)
    return shapes_init, shape_groups_init, parameters


def _render_on_white(render, w: int, h: int, seed: int, shapes, shape_groups) -> torch.Tensor:
    """Rasterise the scene with pydiffvg and alpha-composite it over white.

    Args:
        render: ``pydiffvg.RenderFunction.apply``.
        w, h: Output width and height in pixels.
        seed: Seed forwarded to the pydiffvg renderer.
        shapes, shape_groups: The pydiffvg scene to draw.

    Returns:
        A tensor holding the first three (colour) channels composited over a
        white background. The slicing below assumes the renderer returns a
        channel-last image with alpha at channel index 3.
    """
    scene_args = pydiffvg.RenderFunction.serialize_scene(w, h, shapes, shape_groups)
    # (2, 2) are pydiffvg's per-pixel sample counts along x / y; no background image.
    img = render(w, h, 2, 2, seed, None, *scene_args)
    alpha = img[:, :, 3:4]
    img = alpha * img[:, :, :3] + torch.ones_like(img[:, :, :3]) * (1 - alpha)
    return img[:, :, :3]


def _wandb_log_image(key: str, img: torch.Tensor, step: int) -> None:
    """Plot *img* with matplotlib and log the figure to wandb under *key*."""
    plt.imshow(img.detach().cpu())
    wandb.log({key: wandb.Image(plt)}, step=step)
    plt.close()


# -----------------------------------------------------------------------------
# Public entry-point that CLI *and* FastAPI reuse
# -----------------------------------------------------------------------------
def generate_word_image(cfg, device: torch.device):
    """Optimise a single word and return the path to the resulting PNG.

    Args:
        cfg: Run configuration, either a plain ``dict`` or an ``EasyDict``. A
            plain dict is wrapped so attribute access works below.
        device: Torch device; pydiffvg is switched to GPU mode iff it is CUDA.

    Returns:
        Absolute path of ``<experiment_dir>/output-png/output.png`` when
        ``cfg.save.image`` is enabled, otherwise ``""`` (no PNG is written).

    Raises:
        ValueError: If none of the SDS / tone / conformal losses is enabled, or
            if ``cfg.trainable.point`` is disabled (points are the only
            parameter group this function optimises).
    """
    # make sure we can access attributes no matter if `cfg` is dict or EasyDict
    if isinstance(cfg, dict):
        cfg = edict(cfg)

    use_sds = cfg.loss.use_sds_loss
    use_tone = cfg.loss.tone.use_tone_loss
    use_conformal = cfg.loss.conformal.use_conformal_loss
    if not (use_sds or use_tone or use_conformal):
        # Fail before any expensive / side-effecting work: with no loss term
        # there is nothing to back-propagate.
        raise ValueError(
            "at least one of the SDS, tone or conformal losses must be enabled"
        )

    pydiffvg.set_use_gpu(device.type == "cuda")

    print("preprocessing")
    preprocess(cfg.font, cfg.word, cfg.optimized_letter, cfg.level_of_cc)

    if use_sds:
        sds_loss = SDSLoss(cfg, device)

    h = w = cfg.render_size
    data_augs = get_data_augs(cfg.cut_size)
    render = pydiffvg.RenderFunction.apply

    print("initializing shape")
    shapes, shape_groups, parameters = init_shapes(
        svg_path=cfg.target, trainable=cfg.trainable
    )
    if "point" not in parameters:
        raise ValueError(
            "cfg.trainable.point must be enabled: points are the only optimised parameters"
        )

    img_init = _render_on_white(render, w, h, 0, shapes, shape_groups)
    if cfg.use_wandb:
        _wandb_log_image("init", img_init, 0)

    if use_tone:
        tone_loss = ToneLoss(cfg)
        tone_loss.set_image_init(img_init)

    if cfg.save.init:
        print("saving init")
        filename = os.path.join(cfg.experiment_dir, "svg-init", "init.svg")
        check_and_create_dir(filename)
        save_svg.save_svg(filename, w, h, shapes, shape_groups)

    num_iter = cfg.num_iter
    optim = torch.optim.Adam(
        [{"params": parameters["point"], "lr": cfg.lr_base["point"]}],
        betas=(0.9, 0.9),
        eps=1e-6,
    )

    if use_conformal:
        conformal_loss = ConformalLoss(parameters, device, cfg.optimized_letter, shape_groups)

    def lr_lambda(step):
        # LambdaLR multiplies the group's initial lr by this factor, so the
        # decayed rate is normalised by lr_init to make it a relative multiplier.
        return learning_rate_decay(
            step,
            cfg.lr.lr_init,
            cfg.lr.lr_final,
            num_iter,
            lr_delay_steps=cfg.lr.lr_delay_steps,
            lr_delay_mult=cfg.lr.lr_delay_mult,
        ) / cfg.lr.lr_init

    scheduler = LambdaLR(optim, lr_lambda=lr_lambda, last_epoch=-1)

    print("start training")
    for step in tqdm(range(num_iter)):
        if cfg.use_wandb:
            wandb.log({"learning_rate": optim.param_groups[0]["lr"]}, step=step)
        optim.zero_grad()

        img = _render_on_white(render, w, h, step, shapes, shape_groups)

        if cfg.save.video and (step % cfg.save.video_frame_freq == 0 or step == num_iter - 1):
            save_image(
                img,
                os.path.join(cfg.experiment_dir, "video-png", f"iter{step:04d}.png"),
                gamma,
            )
            svg_frame = os.path.join(cfg.experiment_dir, "video-svg", f"iter{step:04d}.svg")
            check_and_create_dir(svg_frame)
            save_svg.save_svg(svg_frame, w, h, shapes, shape_groups)
            if cfg.use_wandb:
                _wandb_log_image("img", img, step)

        x = img.unsqueeze(0).permute(0, 3, 1, 2).repeat(cfg.batch_size, 1, 1, 1)

        # Only add the enabled terms; previously `sds_loss` was called even when
        # it had never been constructed.
        terms = []
        if use_sds:
            x_aug = data_augs.forward(x)
            terms.append(sds_loss(x_aug))
        if use_tone:
            terms.append(tone_loss(x, step))
        if use_conformal:
            terms.append(cfg.loss.conformal.angeles_w * conformal_loss())
        loss = terms[0]
        for term in terms[1:]:
            loss = loss + term

        if cfg.use_wandb:
            wandb.log({"total_loss": loss.item()}, step=step)

        loss.backward()
        optim.step()
        scheduler.step()

    svg_out = os.path.join(cfg.experiment_dir, "output-svg", "output.svg")
    check_and_create_dir(svg_out)
    save_svg.save_svg(svg_out, w, h, shapes, shape_groups)

    combine_word(cfg.word, cfg.optimized_letter, cfg.font, cfg.experiment_dir)

    png_out = ""
    if cfg.save.image:
        # Re-render the final shapes: the last in-loop `img` was rendered before
        # the final optimizer step (so it lagged output.svg by one update) and
        # does not exist at all when num_iter == 0.
        with torch.no_grad():
            img = _render_on_white(render, w, h, num_iter, shapes, shape_groups)
        png_out = os.path.join(cfg.experiment_dir, "output-png", "output.png")
        check_and_create_dir(png_out)
        pydiffvg.imwrite(img.detach().cpu(), png_out, gamma=gamma)
        if cfg.use_wandb:
            _wandb_log_image("img", img, num_iter)

    if cfg.save.video:
        print("saving video")
        create_video(cfg.num_iter, cfg.experiment_dir, cfg.save.video_frame_freq)

    if cfg.use_wandb:
        wandb.finish()

    # os.path.abspath("") would return the CWD, not a PNG, so keep "" as-is.
    return os.path.abspath(png_out) if png_out else ""


# -----------------------------------------------------------------------------
# CLI entry-point – original behaviour when run directly
# -----------------------------------------------------------------------------
def cli_entry():
    """Build the config from the command line and run one optimisation."""
    cfg = set_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generate_word_image(cfg, device)


if __name__ == "__main__":
    cli_entry()