Spaces:
Running
on
L4
Running
on
L4
| #!/usr/bin/env python | |
| from argparse import ArgumentParser, Namespace | |
| import pickle | |
| import jax | |
| from jax import jit | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from PIL import Image | |
| from model import build_thera | |
| from utils import make_grid, interpolate_grid | |
# Per-channel statistics: used with jax.nn.standardize on the encoder input
# and to map the decoder output back to the input value range.
MEAN = jnp.array((0.4488, 0.4371, 0.4040))
VAR = jnp.array((0.25,) * 3)
# Side length of each square block of sampling coordinates handed to the
# decoder in a single call.
PATCH_SIZE = 256
def process_single(source, apply_encoder, apply_decoder, params, target_shape):
    """Super-resolve a single, un-batched image to ``target_shape``.

    The source is upsampled with ``interpolate_grid`` at the target grid.
    The decoder is then evaluated patch by patch (``PATCH_SIZE`` x
    ``PATCH_SIZE`` sampling coordinates per call), de-standardized with
    ``MEAN``/``VAR``, and added to that upsampled image as a residual.

    Args:
        source: Image array whose axis 0 is height. It is standardized with
            the 3-element ``MEAN``/``VAR``, so presumably ``(H, W, 3)``
            (TODO confirm).
        apply_encoder: Called as ``apply_encoder(params, batched_source)``.
        apply_decoder: Called as
            ``apply_decoder(params, encoding, coords_patch, t)``.
        params: Model parameters forwarded to both callables.
        target_shape: Output ``(height, width)`` passed to ``make_grid``.

    Returns:
        Array shaped like the interpolated source (it carries the leading
        batch axis added by ``source[None]``; presumably ``(1, H', W', C)``,
        TODO confirm).
    """
    # Scale-conditioning value t = scale**-2. Height is compared with height.
    # The previous code divided the target height by the source *width*. That
    # was wrong for non-square images and made t differ between the rotated
    # members of the geo-ensemble. With anisotropic targets only the height
    # ratio is used.
    scale = target_shape[0] / source.shape[0]
    t = jnp.float32(scale ** -2)[None]

    coords_nearest = jnp.asarray(make_grid(target_shape)[None])
    source_up = interpolate_grid(coords_nearest, source[None])
    source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]
    encoding = apply_encoder(params, source)

    coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])  # global sampling coords
    # NaN-initialised, presumably so any region the patch loop fails to
    # write stands out in the result.
    out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)
    height, width = coords.shape[1], coords.shape[2]
    for h_min in range(0, height, PATCH_SIZE):
        h_max = min(h_min + PATCH_SIZE, height)
        for w_min in range(0, width, PATCH_SIZE):
            # apply decoder with one patch of coordinates
            w_max = min(w_min + PATCH_SIZE, width)
            coords_patch = coords[:, h_min:h_max, w_min:w_max]
            out_patch = apply_decoder(params, encoding, coords_patch, t)
            out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)

    # Undo the standardization, then add the interpolated image as a residual.
    out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
    return out + source_up
def process(source, model, params, target_shape, do_ensemble=True):
    """Super-resolve ``source`` to ``target_shape`` and return a uint8 image.

    With ``do_ensemble`` the image is processed at the four 90-degree
    rotations. Each prediction is rotated back and all four are averaged.
    Otherwise a single pass is made. The averaged result is clipped to
    [0, 1], scaled to 0..255, rounded, and returned without the leading
    batch axis.
    """
    encode = jit(model.apply_encoder)
    decode = jit(model.apply_decoder)
    n_rotations = 4 if do_ensemble else 1

    predictions = []
    for k in range(n_rotations):
        rotated = jnp.rot90(source, k=k, axes=(-3, -2))
        # A quarter turn swaps the requested output height and width.
        if k % 2 == 0:
            shape_k = target_shape
        else:
            shape_k = tuple(reversed(target_shape))
        pred = process_single(rotated, encode, decode, params, shape_k)
        # Reversed axis order rotates the opposite way, undoing the turn.
        predictions.append(jnp.rot90(pred, k=k, axes=(-2, -3)))

    averaged = jnp.stack(predictions).mean(0).clip(0., 1.)
    first = averaged[0]
    return jnp.rint(first * 255).astype(jnp.uint8)
def main(args: Namespace):
    """Super-resolve ``args.in_file`` and save the result to ``args.out_file``.

    Exactly one of ``args.scale`` or ``args.size`` selects the output size.
    ``args.scale`` multiplies the input height and width; ``args.size`` gives
    an explicit ``(height, width)``. The model is rebuilt from the pickled
    checkpoint at ``args.checkpoint``. ``args.no_ensemble`` disables the
    rotation ensemble.

    Raises:
        ValueError: If both or neither of ``scale``/``size`` are given, if
            the resulting target size is not positive, or if no checkpoint
            path is given.
    """
    # Force 3-channel RGB. The model is built with 3 channels and MEAN/VAR
    # hold 3 entries, so grayscale (H, W) or RGBA (H, W, 4) arrays would
    # otherwise break standardization. The ``with`` closes the file handle
    # once the pixels have been read.
    with Image.open(args.in_file) as img:
        source = np.asarray(img.convert('RGB')) / 255.

    if args.scale is not None:
        if args.size is not None:
            raise ValueError('Cannot specify both size and scale')
        target_shape = (
            round(source.shape[0] * args.scale),
            round(source.shape[1] * args.scale),
        )
    elif args.size is not None:
        # argparse yields a list; use a tuple like the --scale branch does.
        target_shape = tuple(args.size)
    else:
        raise ValueError('Must specify either size or scale')
    if min(target_shape) < 1:
        raise ValueError(f'Target size must be positive, got {target_shape}')

    if args.checkpoint is None:
        raise ValueError('Must specify a checkpoint')
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load checkpoints from trusted sources.
    with open(args.checkpoint, 'rb') as fh:
        check = pickle.load(fh)
    params, backbone, size = check['model'], check['backbone'], check['size']
    model = build_thera(3, backbone, size)

    out = process(source, model, params, target_shape, not args.no_ensemble)
    Image.fromarray(np.asarray(out)).save(args.out_file)
def parse_args(argv=None) -> Namespace:
    """Parse the command line of the super-resolution script.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        Namespace with ``in_file``, ``out_file``, ``scale``, ``size``,
        ``checkpoint`` and ``no_ensemble``.
    """
    parser = ArgumentParser()
    parser.add_argument('in_file')
    parser.add_argument('out_file')
    # --scale and --size are alternative ways of choosing the output size.
    # Let argparse enforce "exactly one", with a usage message, instead of
    # failing later in main().
    target = parser.add_mutually_exclusive_group(required=True)
    target.add_argument('--scale', type=float, help='Scale factor for super-resolution')
    target.add_argument('--size', type=int, nargs=2,
                        help='Target size (h, w), mutually exclusive with --scale')
    # The checkpoint is always opened, so it is not actually optional.
    parser.add_argument('--checkpoint', required=True, help='Path to checkpoint file')
    parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse the command line and run the pipeline.
    main(parse_args())