# --- wai_service.py (final handler) ---------------------------------
"""Service entry point: turn a request payload into a base64-encoded image.

``handler(payload)`` builds the project config from the payload, overlays the
payload onto it, fills in defaults, runs ``generate_word_image`` once and
returns the produced file's bytes base64-encoded.
"""
import base64
import contextlib
import logging
import os
import sys
import tempfile

import torch

# Also make modules inside ./code importable as top-level names (presumably
# code/* imports its siblings that way — TODO confirm). Must run before the
# project imports below.
sys.path.append(os.path.join(os.path.dirname(__file__), "code"))

# NOTE(review): the project package is named ``code``, which is also a stdlib
# module; if stdlib ``code`` gets imported first these imports will break.
from code.config import set_config  # noqa: E402
from code.main import generate_word_image  # noqa: E402
from easydict import EasyDict  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# flags that *are* recognised by code/config.parse_args()
KNOWN_CLI_KEYS = {
    "word",
    "optimized_letter",
    "font",
    "seed",
    "experiment",
    "use_wandb",
    "wandb_user",
}

# Directory where preprocessing expects the seed SVGs.
INIT_DIR = os.path.join("code", "data", "init")

DEFAULT_RENDER_SIZE = 384
DEFAULT_DIFFUSION_MODEL = "runwayml/stable-diffusion-v1-5"


def _sanitize(cfg):
    """Replace every Ellipsis (``...``) placeholder with ``None``, in place.

    Recurses through dicts (including EasyDicts) and lists; any other value is
    left untouched. Returns the same object it was given.
    """
    if isinstance(cfg, dict):
        items = cfg.items()
    elif isinstance(cfg, list):
        items = enumerate(cfg)
    else:
        return cfg
    for key, value in items:
        if value is Ellipsis:
            # Same key/index, so the container's size never changes mid-iteration.
            cfg[key] = None
        else:
            _sanitize(value)
    return cfg


@contextlib.contextmanager
def _token_link(token):
    """Expose *token* to code/config.py through a ``TOKEN`` symlink in the CWD.

    The token is written to a temp file in the CWD, ``TOKEN_FILE`` is pointed at
    it (for debugging) and ``TOKEN`` is symlinked to it. On exit — including
    when anything inside raises — the symlink and temp file are removed and the
    previous ``TOKEN_FILE`` value is restored.

    Raises:
        FileExistsError: if a ``TOKEN`` entry already exists in the CWD.
    """
    prev_token_env = os.environ.get("TOKEN_FILE")
    tmp = tempfile.NamedTemporaryFile(
        delete=False, dir=".", prefix="TOKEN_", mode="w"
    )
    linked = False
    try:
        with tmp:  # closes (but, with delete=False, keeps) the file
            tmp.write(token)
        os.environ["TOKEN_FILE"] = tmp.name  # optional: lets you debug
        # make config.py look for our temp file
        os.symlink(tmp.name, "TOKEN")
        linked = True
        yield
    finally:
        # Only remove a symlink we actually created; never a pre-existing TOKEN.
        if linked:
            os.remove("TOKEN")
        os.unlink(tmp.name)
        if prev_token_env is None:
            os.environ.pop("TOKEN_FILE", None)
        else:
            os.environ["TOKEN_FILE"] = prev_token_env


def _cli_argv(payload):
    """Build a fake ``sys.argv`` holding only the flags config.parse_args() knows."""
    argv = [sys.argv[0]]
    for key in sorted(KNOWN_CLI_KEYS & payload.keys()):
        argv += [f"--{key}", str(payload[key])]
    return argv


def _load_config(payload):
    """Run ``set_config()`` against an argv built from *payload*.

    ``sys.argv`` is swapped only for the duration of the call and is always
    restored. If *payload* has a ``"token"`` key, it is exposed via
    ``_token_link`` while the config is built. Returns the config with Ellipsis
    placeholders replaced by ``None``.
    """
    token_ctx = (
        _token_link(payload["token"])
        if "token" in payload
        else contextlib.nullcontext()
    )
    cli_argv = _cli_argv(payload)
    orig_argv = sys.argv[:]
    with token_ctx:
        sys.argv = cli_argv
        try:
            cfg = set_config()  # EasyDict with YAML + CLI-compatible fields
        finally:
            sys.argv = orig_argv
    return _sanitize(cfg)


def _apply_defaults(cfg):
    """Fill in fields the optimisation needs when config/payload left them unset.

    ``None`` counts as unset. After ``_sanitize`` this is how ``...`` placeholders
    appear.

    Raises:
        ValueError: if no non-empty ``word`` is available.
    """
    if getattr(cfg, "render_size", None) is None:
        cfg.render_size = DEFAULT_RENDER_SIZE

    word = getattr(cfg, "word", None)
    if not word:
        raise ValueError("payload/config must provide a non-empty 'word'")
    cfg.word = str(word).upper()

    # Checked lazily so cfg.word[-1] is only evaluated when actually needed.
    if not getattr(cfg, "optimized_letter", None):
        cfg.optimized_letter = cfg.word[-1]

    if getattr(cfg, "diffusion", None) is None:
        cfg.diffusion = EasyDict()
    model = getattr(cfg.diffusion, "model", None)
    if model is None or model is Ellipsis:
        # Re-read cfg.diffusion instead of keeping a local alias: EasyDict
        # copies dict values on assignment, so an alias could be stale.
        cfg.diffusion.model = DEFAULT_DIFFUSION_MODEL


def handler(payload: dict) -> str:
    """Generate a word image for *payload* and return it base64-encoded.

    Args:
        payload: request fields. Keys in ``KNOWN_CLI_KEYS`` are passed to
            ``set_config()`` as CLI flags. An optional ``"token"`` is exposed to
            config.py through a temporary ``TOKEN`` symlink. Every key is then
            copied onto the config as an attribute.

    Returns:
        The bytes of the file whose path ``generate_word_image`` returns,
        base64-encoded as a ``str``.

    Raises:
        ValueError: if no non-empty ``word`` is available after the overlay.
    """
    if not os.path.isdir(INIT_DIR):
        os.makedirs(INIT_DIR, exist_ok=True)
        logging.warning("Created missing directory %s", INIT_DIR)

    cfg = _load_config(payload)

    # Overlay ALL payload keys (new ones like render_size stick).
    # NOTE(review): this also copies "token" onto cfg; confirm cfg is never
    # logged/uploaded (e.g. to wandb) with the secret in it.
    for key, value in payload.items():
        setattr(cfg, key, value)

    _apply_defaults(cfg)

    # Checked after overlay + defaults so the path reflects the final settings.
    expected_svg = os.path.join(
        INIT_DIR,
        f"{getattr(cfg, 'font', None)}_{cfg.optimized_letter}_scaled.svg",
    )
    if not os.path.exists(expected_svg):
        logging.warning("Expected seed SVG not found: %s (will be generated)", expected_svg)

    # Run the optimisation exactly once, after every default is in place.
    out_path = generate_word_image(cfg, device)

    with open(out_path, "rb") as f:
        return base64.b64encode(f.read()).decode()
# --------------------------------------------------------------------