|
|
|
|
|
import base64, sys, os, torch, tempfile, logging |
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "code")) |
|
|
|
|
|
from code.config import set_config |
|
|
from code.main import generate_word_image |
|
|
from easydict import EasyDict |
|
|
|
|
|
# Root logging: "HH:MM:SS | LEVEL | message" lines at INFO and above.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Payload keys that handler() forwards to set_config() as ``--key value`` flags.
KNOWN_CLI_KEYS = set(
    (
        "word",
        "optimized_letter",
        "font",
        "seed",
        "experiment",
        "use_wandb",
        "wandb_user",
    )
)
|
|
|
|
|
def _sanitize(cfg): |
|
|
""" |
|
|
Recursively walk an EasyDict / dict and replace every Ellipsis (`...`) |
|
|
with None. Returns the same object (in-place). |
|
|
""" |
|
|
if isinstance(cfg, dict): |
|
|
for k, v in cfg.items(): |
|
|
if v is Ellipsis: |
|
|
cfg[k] = None |
|
|
else: |
|
|
_sanitize(v) |
|
|
return cfg |
|
|
|
|
|
|
|
|
def _remove_if_present(path):
    """Delete the file or symlink at *path*; a path that is already gone is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _load_config(payload):
    """Build the config by running ``set_config()`` on CLI-style keys from *payload*.

    ``set_config`` reads ``sys.argv``, so ``sys.argv`` is temporarily swapped
    for ``--key value`` pairs built from the payload keys listed in
    ``KNOWN_CLI_KEYS``. If the payload carries a ``"token"``, that token is
    staged for the duration of the call in three places:

    * a temporary ``TOKEN_*`` file in the working directory;
    * the ``TOKEN_FILE`` environment variable;
    * a ``TOKEN`` symlink pointing at the temp file.

    ``sys.argv``, the temp file, the symlink and ``TOKEN_FILE`` are all
    restored or removed afterwards, even if ``set_config`` (or
    ``os.symlink``) raises.

    Raises:
        FileExistsError: if a ``TOKEN`` entry already exists in the working
            directory when a token has to be linked.
    """
    cli_argv = [sys.argv[0]]
    # sorted() keeps the synthetic argv deterministic (set order is not).
    for k in sorted(KNOWN_CLI_KEYS & payload.keys()):
        cli_argv += [f"--{k}", str(payload[k])]

    token_path = None
    linked = False
    prev_token_env = os.environ.get("TOKEN_FILE")
    orig_argv = sys.argv[:]
    try:
        if "token" in payload:
            with tempfile.NamedTemporaryFile(
                mode="w", dir=".", prefix="TOKEN_", delete=False
            ) as fh:
                token_path = fh.name  # recorded first so cleanup runs even if write fails
                fh.write(payload["token"])
            os.environ["TOKEN_FILE"] = token_path
            os.symlink(token_path, "TOKEN")
            linked = True
        sys.argv = cli_argv
        return set_config()
    finally:
        sys.argv = orig_argv
        if linked:
            _remove_if_present("TOKEN")
        if token_path is not None:
            _remove_if_present(token_path)
            # Do not leave TOKEN_FILE pointing at the file just deleted.
            if prev_token_env is None:
                os.environ.pop("TOKEN_FILE", None)
            else:
                os.environ["TOKEN_FILE"] = prev_token_env


def _apply_defaults(cfg):
    """Normalise *cfg* in place and fill defaults for settings that are unset.

    ``_sanitize`` turns ``...`` placeholders into None. "Unset" therefore
    means absent or None, or, for the diffusion model, still Ellipsis.
    """
    if getattr(cfg, "render_size", None) is None:
        cfg.render_size = 384
    cfg.word = cfg.word.upper()
    # NOTE(review): optimized_letter is not upper-cased along with word;
    # confirm callers send it in the case the pipeline expects.
    if getattr(cfg, "optimized_letter", None) is None:
        cfg.optimized_letter = cfg.word[-1]
    model = getattr(cfg.diffusion, "model", None)
    if model is None or model is Ellipsis:
        cfg.diffusion.model = "runwayml/stable-diffusion-v1-5"


def handler(payload: dict) -> str:
    """Generate a word image for *payload* and return it base64-encoded.

    Args:
        payload: request fields, used in three stages:

            * keys listed in ``KNOWN_CLI_KEYS`` are passed to
              ``set_config()`` as ``--key value`` CLI arguments;
            * an optional ``"token"`` is staged on disk while
              ``set_config()`` runs;
            * finally, every payload key is also set as an attribute on the
              resulting config, overriding what ``set_config()`` produced.

    Returns:
        The contents of the file whose path ``generate_word_image`` returns,
        base64-encoded and decoded to ``str``.
    """
    init_dir = os.path.join("code", "data", "init")
    if not os.path.isdir(init_dir):
        os.makedirs(init_dir, exist_ok=True)
        logging.warning("Created missing directory %s", init_dir)

    cfg = _load_config(payload)
    _sanitize(cfg)

    # Payload values win over whatever set_config() derived.
    for k, v in payload.items():
        setattr(cfg, k, v)

    _apply_defaults(cfg)

    # Checked after overrides/defaults so the logged path reflects the final config.
    expected_svg = os.path.join(
        init_dir, f"{cfg.font}_{cfg.optimized_letter}_scaled.svg"
    )
    if not os.path.exists(expected_svg):
        logging.warning("Expected seed SVG not found: %s (will be generated)", expected_svg)

    # Run the (expensive) pipeline exactly once, on the fully normalised config.
    out_path = generate_word_image(cfg, device)

    with open(out_path, "rb") as f:
        return base64.b64encode(f.read()).decode()
|
|
|
|
|
|