File size: 3,557 Bytes
5e3465d
d0a8a94
5e3465d
 
 
 
 
 
d0a8a94
 
 
 
 
 
5e3465d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4ae355
 
d0a8a94
 
 
 
 
 
801d57b
 
 
 
 
 
 
 
 
 
5e3465d
 
 
 
 
 
 
 
 
 
 
801d57b
 
 
5e3465d
 
9e16acb
 
 
 
 
 
 
 
5e3465d
 
 
 
d0a8a94
 
 
 
5e3465d
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# --- wai_service.py (final handler) ---------------------------------
import base64, sys, os, torch, tempfile, logging
sys.path.append(os.path.join(os.path.dirname(__file__), "code"))

from code.config import set_config
from code.main import generate_word_image
from easydict import EasyDict

# Root logging: INFO and above, one line per record with a wall-clock time.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# Use the GPU whenever torch can see one; otherwise run on the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Payload keys that are forwarded as `--key value` flags, because
# code/config.parse_args() recognises them.
KNOWN_CLI_KEYS = set(
    "word optimized_letter font seed experiment use_wandb wandb_user".split()
)

def _sanitize(cfg):
    """
    Recursively walk an EasyDict / dict and replace every Ellipsis (`...`)
    with None.  Returns the same object (in-place).
    """
    if isinstance(cfg, dict):
        for k, v in cfg.items():
            if v is Ellipsis:
                cfg[k] = None
            else:
                _sanitize(v)
    return cfg


def _write_token_file(token):
    """Write *token* to a new, non-auto-deleting temp file in the CWD; return its path."""
    with tempfile.NamedTemporaryFile(
        delete=False, dir=".", prefix="TOKEN_", mode="w"
    ) as fh:
        fh.write(token)
    return fh.name


def _load_config(payload):
    """Call ``set_config()`` with a fake ``sys.argv`` built from *payload*.

    Only keys in ``KNOWN_CLI_KEYS`` are forwarded, each as ``--key str(value)``.
    If *payload* has a ``token``, it is written to a temp file that is exposed
    to ``set_config()`` through a ``TOKEN`` symlink in the current directory.
    ``sys.argv`` is restored, and the symlink, temp file and ``TOKEN_FILE``
    environment variable are removed, whether ``set_config()`` returns or
    raises.
    """
    import contextlib

    cli_argv = [sys.argv[0]]
    for key in KNOWN_CLI_KEYS & payload.keys():
        cli_argv += [f"--{key}", str(payload[key])]

    token_path = None
    link_created = False
    orig_argv = sys.argv[:]
    try:
        if "token" in payload:
            token_path = _write_token_file(payload["token"])
            os.environ["TOKEN_FILE"] = token_path  # debugging aid only
            # config.py looks for ./TOKEN; point it at our temp file.
            os.symlink(token_path, "TOKEN")
            link_created = True
        sys.argv = cli_argv
        return set_config()  # EasyDict with YAML + CLI-compatible fields
    finally:
        sys.argv = orig_argv
        # Cleanup lives in `finally` so a failing symlink/set_config never
        # leaves the secret token on disk.
        if link_created:
            with contextlib.suppress(FileNotFoundError):
                os.remove("TOKEN")
        if token_path is not None:
            # Don't leave a pointer to a deleted file for the next job.
            if os.environ.get("TOKEN_FILE") == token_path:
                del os.environ["TOKEN_FILE"]
            with contextlib.suppress(FileNotFoundError):
                os.unlink(token_path)


def _apply_defaults(cfg):
    """Fill in pipeline defaults for values that are missing or None on *cfg*.

    After ``_sanitize()`` a YAML ``...`` placeholder reads as None, so None is
    treated the same as "not set".
    """
    if getattr(cfg, "render_size", None) is None:
        cfg.render_size = 384
    cfg.word = cfg.word.upper()
    if getattr(cfg, "optimized_letter", None) is None:
        cfg.optimized_letter = cfg.word[-1]
    model = getattr(cfg.diffusion, "model", None)
    if model is None or model is Ellipsis:
        cfg.diffusion.model = "runwayml/stable-diffusion-v1-5"


def handler(payload: dict) -> str:
    """Generate a word image for one job and return it base64-encoded.

    Args:
        payload: Job parameters. Keys in ``KNOWN_CLI_KEYS`` are passed to
            ``set_config()`` as CLI flags. Every key, including ones
            ``set_config()`` does not know such as ``render_size``, is then
            copied onto the config object. An optional ``token`` is made
            available to ``set_config()`` via a temporary ``TOKEN`` symlink.

    Returns:
        The contents of the file whose path ``generate_word_image()`` returns,
        base64-encoded as an ASCII ``str``.
    """
    logging.info("handler received job (word=%s, has_token=%s)",
                 payload.get("word"), "token" in payload)
    init_dir = os.path.join("code", "data", "init")
    if not os.path.isdir(init_dir):
        os.makedirs(init_dir, exist_ok=True)
        logging.warning("Created missing directory %s", init_dir)

    # 1) Build the config from YAML + recognised CLI-style payload keys.
    cfg = _load_config(payload)
    _sanitize(cfg)

    # 2) Overlay ALL payload keys, so new ones like render_size stick.
    # NOTE(review): this also copies the raw `token` onto cfg; confirm
    # nothing downstream logs or persists the config.
    for key, value in payload.items():
        setattr(cfg, key, value)

    _apply_defaults(cfg)

    # Check the seed SVG only after overlay + defaults, so the path reflects
    # the values actually used by preprocessing.
    expected_svg = os.path.join(
        init_dir, f"{cfg.font}_{cfg.optimized_letter}_scaled.svg"
    )
    if not os.path.exists(expected_svg):
        logging.warning("Expected seed SVG not found: %s (will be generated)", expected_svg)

    # 3) Run the optimisation.
    out_path = generate_word_image(cfg, device)

    # 4) Return the result base64-encoded.
    with open(out_path, "rb") as f:
        return base64.b64encode(f.read()).decode()
# --------------------------------------------------------------------