word-as-image-api / wai_service.py
KingHacker9000's picture
Fix cfg-before-assignment; move expected_svg debug block
9e16acb
raw
history blame
3.48 kB
# --- wai_service.py (final handler) ---------------------------------
import base64, sys, os, torch, tempfile, logging
sys.path.append(os.path.join(os.path.dirname(__file__), "code"))
from code.config import set_config
from code.main import generate_word_image
from easydict import EasyDict
# Root logger: one "HH:MM:SS | LEVEL | message" line per record, INFO and up.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Flags recognised by code/config.parse_args(); only payload keys listed here
# are forwarded to it as --flag arguments.
KNOWN_CLI_KEYS = set(
    "word optimized_letter font seed experiment use_wandb wandb_user".split()
)
def _sanitize(cfg):
"""
Recursively walk an EasyDict / dict and replace every Ellipsis (`...`)
with None. Returns the same object (in-place).
"""
if isinstance(cfg, dict):
for k, v in cfg.items():
if v is Ellipsis:
cfg[k] = None
else:
_sanitize(v)
return cfg
def handler(payload: dict) -> str:
    """Run one Word-As-Image optimisation and return the result as base-64.

    Steps:
      1. If ``payload["token"]`` is present, write it to a temp file and expose
         it as a ``TOKEN`` symlink while ``set_config()`` runs (the comment in
         the original code says config.py looks for that file).
      2. Build ``cfg`` via ``set_config()`` with a fake ``sys.argv`` containing
         only the payload keys listed in ``KNOWN_CLI_KEYS``.
      3. Replace ``...`` placeholders with ``None``, overlay every payload key
         onto ``cfg`` and fill in defaults.
      4. Call ``generate_word_image(cfg, device)`` once and return the bytes of
         the file whose path it returns, base-64 encoded.

    Args:
        payload: request fields. Keys in ``KNOWN_CLI_KEYS`` are passed as CLI
            flags; every key (including extras such as ``render_size``) is then
            also set as an attribute on ``cfg``.

    Returns:
        The output file's contents as an ASCII base-64 string.

    Raises:
        FileExistsError: if ``payload`` has a token and a ``TOKEN`` entry
            already exists in the working directory (the temp file is still
            cleaned up).
    """
    init_dir = os.path.join("code", "data", "init")
    if not os.path.isdir(init_dir):
        os.makedirs(init_dir, exist_ok=True)
        logging.warning("Created missing directory %s", init_dir)

    # 1️⃣ Build fake argv *only* from recognised keys (sorted for a stable order).
    cli_argv = [sys.argv[0]]
    for k in sorted(KNOWN_CLI_KEYS & payload.keys()):
        cli_argv += [f"--{k}", str(payload[k])]

    tmp_token = None      # temp file holding the token, if one was staged
    token_linked = False  # True only once *we* created the TOKEN symlink
    orig_argv = sys.argv[:]
    try:
        if "token" in payload:
            tmp_token = tempfile.NamedTemporaryFile(
                delete=False, dir=".", prefix="TOKEN_", mode="w"
            )
            with tmp_token:  # closes (but, with delete=False, keeps) the file
                tmp_token.write(payload["token"])
            os.environ["TOKEN_FILE"] = tmp_token.name  # optional: let you debug
            # make config.py look for our temp file
            os.symlink(tmp_token.name, "TOKEN")
            token_linked = True
        sys.argv = cli_argv
        cfg = set_config()  # EasyDict with YAML + CLI-compatible fields
    finally:
        sys.argv = orig_argv
        # Never delete a pre-existing TOKEN: only remove the link we made.
        if token_linked:
            os.remove("TOKEN")
        # Always remove the secret-bearing temp file, even if linking failed.
        if tmp_token is not None:
            os.unlink(tmp_token.name)

    _sanitize(cfg)

    # 2️⃣ Overlay ALL payload keys (new ones like render_size stick).
    # NOTE(review): this also copies payload["token"] onto cfg.token; confirm
    # nothing downstream logs or uploads cfg (e.g. to wandb) verbatim.
    for k, v in payload.items():
        setattr(cfg, k, v)

    # Sensible defaults. A value of None counts as "unset", because
    # _sanitize() has already turned YAML `...` placeholders into None.
    if getattr(cfg, "render_size", None) is None:
        cfg.render_size = 384
    cfg.word = cfg.word.upper()
    if not getattr(cfg, "optimized_letter", None):
        cfg.optimized_letter = cfg.word[-1]
    model = getattr(cfg.diffusion, "model", None)
    if model is None or model is Ellipsis:
        cfg.diffusion.model = "runwayml/stable-diffusion-v1-5"

    # The SVG path that preprocessing expects (checked after the defaults so
    # it reflects the letter that will actually be optimised).
    expected_svg = os.path.join(
        init_dir, f"{cfg.font}_{cfg.optimized_letter}_scaled.svg"
    )
    if not os.path.exists(expected_svg):
        logging.warning(
            "Expected seed SVG not found: %s (will be generated)", expected_svg
        )

    # 3️⃣ Run optimisation — exactly once, with the fully prepared cfg.
    out_path = generate_word_image(cfg, device)

    # 4️⃣ Return base-64
    with open(out_path, "rb") as f:
        return base64.b64encode(f.read()).decode()
# --------------------------------------------------------------------