akhaliq3 committed · Commit 4a7bfa8 · Parent(s): 506da10
app file
app.py ADDED
@@ -0,0 +1,224 @@
import collections
import os
import sys
import tempfile
import urllib.request  # `import urllib` alone does not expose urllib.request on Python 3

from subprocess import call

from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import requests
import tensorflow as tf
import gradio as gr

# Download two street-scene photos to use as demo examples.
url1 = 'https://cdn.pixabay.com/photo/2014/09/07/21/52/city-438393_1280.jpg'
r = requests.get(url1, allow_redirects=True)
open("city1.jpg", 'wb').write(r.content)
url2 = 'https://cdn.pixabay.com/photo/2016/02/19/11/36/canal-1209808_1280.jpg'
r = requests.get(url2, allow_redirects=True)
open("city2.jpg", 'wb').write(r.content)
# Static information about the dataset a checkpoint was trained on.
DatasetInfo = collections.namedtuple(
    'DatasetInfo',
    'num_classes, label_divisor, thing_list, colormap, class_names')


def _cityscapes_label_colormap():
  """Creates a label colormap used in CITYSCAPES segmentation benchmark.

  See more about CITYSCAPES dataset at https://www.cityscapes-dataset.com/
  M. Cordts, et al. "The Cityscapes Dataset for Semantic Urban Scene
  Understanding." CVPR. 2016.

  Returns:
    A 2-D numpy array with each row being mapped RGB color (in uint8 range).
  """
  colormap = np.zeros((256, 3), dtype=np.uint8)
  colormap[0] = [128, 64, 128]
  colormap[1] = [244, 35, 232]
  colormap[2] = [70, 70, 70]
  colormap[3] = [102, 102, 156]
  colormap[4] = [190, 153, 153]
  colormap[5] = [153, 153, 153]
  colormap[6] = [250, 170, 30]
  colormap[7] = [220, 220, 0]
  colormap[8] = [107, 142, 35]
  colormap[9] = [152, 251, 152]
  colormap[10] = [70, 130, 180]
  colormap[11] = [220, 20, 60]
  colormap[12] = [255, 0, 0]
  colormap[13] = [0, 0, 142]
  colormap[14] = [0, 0, 70]
  colormap[15] = [0, 60, 100]
  colormap[16] = [0, 80, 100]
  colormap[17] = [0, 0, 230]
  colormap[18] = [119, 11, 32]
  return colormap


def _cityscapes_class_names():
  return ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
          'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
          'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
          'bicycle')


def cityscapes_dataset_information():
  return DatasetInfo(
      num_classes=19,
      label_divisor=1000,
      thing_list=tuple(range(11, 19)),
      colormap=_cityscapes_label_colormap(),
      class_names=_cityscapes_class_names())
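The `label_divisor` is what lets a single integer map carry both labels: each panoptic ID packs a semantic class and an instance index as `semantic_id * label_divisor + instance_id`. A minimal sketch of the decoding (the helper `decode_panoptic_id` is hypothetical, not part of this app):

# Sketch of the panoptic ID encoding implied by `label_divisor`.
# `decode_panoptic_id` is a hypothetical helper, not part of this app.
def decode_panoptic_id(panoptic_id, label_divisor=1000):
  semantic_id = panoptic_id // label_divisor
  instance_id = panoptic_id % label_divisor
  return semantic_id, instance_id

# 13 * 1000 + 7 -> the 7th `car` instance (class index 13 in Cityscapes).
assert decode_panoptic_id(13007) == (13, 7)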
def perturb_color(color, noise, used_colors, max_trials=50, random_state=None):
  """Perturbs the color with some noise.

  If `used_colors` is not None, we will return a color that has
  not appeared in it before.

  Args:
    color: A numpy array with three elements [R, G, B].
    noise: Integer, specifying the amount of perturbing noise (in uint8 range).
    used_colors: A set, used to keep track of used colors.
    max_trials: An integer, maximum trials to generate random color.
    random_state: An optional np.random.RandomState. If passed, will be used to
      generate random numbers.

  Returns:
    A perturbed color that has not appeared in used_colors.
  """
  if random_state is None:
    random_state = np.random
  for _ in range(max_trials):
    random_color = color + random_state.randint(
        low=-noise, high=noise + 1, size=3)
    random_color = np.clip(random_color, 0, 255)
    if tuple(random_color) not in used_colors:
      used_colors.add(tuple(random_color))
      return random_color
  print('Max trial reached and duplicate color will be used. Please consider '
        'increasing the noise in `perturb_color()`.')
  return random_color
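To make the contract concrete: each call draws a fresh color near the base class color and records it in the caller-owned set, so repeated calls yield distinct per-instance colors. A quick illustration (values here are only an example, not part of the app):

# Illustration only: two perturbed variants of the Cityscapes `car` color.
used = set()
base = _cityscapes_label_colormap()[13]          # [0, 0, 142]
rng = np.random.RandomState(0)
c1 = perturb_color(base, noise=60, used_colors=used, random_state=rng)
c2 = perturb_color(base, noise=60, used_colors=used, random_state=rng)
assert tuple(c1) != tuple(c2)                    # distinct per instance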
def color_panoptic_map(panoptic_prediction, dataset_info, perturb_noise):
  """Helper method to colorize output panoptic map.

  Args:
    panoptic_prediction: A 2D numpy array, panoptic prediction from deeplab
      model.
    dataset_info: A DatasetInfo object, dataset associated to the model.
    perturb_noise: Integer, the amount of noise (in uint8 range) added to each
      instance of the same semantic class.

  Returns:
    colored_panoptic_map: A 3D numpy array with last dimension of 3, colored
      panoptic prediction map.
    used_colors: A dictionary mapping semantic_ids to a set of colors used
      in `colored_panoptic_map`.
  """
  if panoptic_prediction.ndim != 2:
    raise ValueError('Expect 2-D panoptic prediction. Got {}'.format(
        panoptic_prediction.shape))
  semantic_map = panoptic_prediction // dataset_info.label_divisor
  instance_map = panoptic_prediction % dataset_info.label_divisor
  height, width = panoptic_prediction.shape
  colored_panoptic_map = np.zeros((height, width, 3), dtype=np.uint8)
  used_colors = collections.defaultdict(set)
  # Use a fixed seed to reproduce the same visualization.
  random_state = np.random.RandomState(0)
  unique_semantic_ids = np.unique(semantic_map)
  for semantic_id in unique_semantic_ids:
    semantic_mask = semantic_map == semantic_id
    if semantic_id in dataset_info.thing_list:
      # For `thing` class, we will add a small amount of random noise to its
      # correspondingly predefined semantic segmentation colormap.
      unique_instance_ids = np.unique(instance_map[semantic_mask])
      for instance_id in unique_instance_ids:
        instance_mask = np.logical_and(semantic_mask,
                                       instance_map == instance_id)
        random_color = perturb_color(
            dataset_info.colormap[semantic_id],
            perturb_noise,
            used_colors[semantic_id],
            random_state=random_state)
        colored_panoptic_map[instance_mask] = random_color
    else:
      # For `stuff` class, we use the defined semantic color.
      colored_panoptic_map[semantic_mask] = dataset_info.colormap[semantic_id]
      used_colors[semantic_id].add(tuple(dataset_info.colormap[semantic_id]))
  return colored_panoptic_map, used_colors
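A toy invocation that exercises both branches (class 0 `road` is stuff, class 13 `car` is a thing; the 2x2 input is made up for illustration):

# Toy 2x2 panoptic map: `road` (stuff) on the left, two `car` instances
# (thing, class 13) on the right. Illustration only.
info = cityscapes_dataset_information()
toy = np.array([[0, 13001],
                [0, 13002]])
colored, used = color_panoptic_map(toy, info, perturb_noise=60)
assert colored.shape == (2, 2, 3)
assert len(used[13]) == 2   # one color per car instance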
def vis_segmentation(image,
                     panoptic_prediction,
                     dataset_info,
                     perturb_noise=60):
  """Visualizes input image, segmentation map and overlay view."""
  plt.figure(figsize=(30, 20))
  grid_spec = gridspec.GridSpec(2, 2)

  ax = plt.subplot(grid_spec[0])
  plt.imshow(image)
  plt.axis('off')
  ax.set_title('input image', fontsize=20)

  ax = plt.subplot(grid_spec[1])
  panoptic_map, used_colors = color_panoptic_map(panoptic_prediction,
                                                 dataset_info, perturb_noise)
  plt.imshow(panoptic_map)
  plt.axis('off')
  ax.set_title('panoptic map', fontsize=20)

  ax = plt.subplot(grid_spec[2])
  plt.imshow(image)
  plt.imshow(panoptic_map, alpha=0.7)
  plt.axis('off')
  ax.set_title('panoptic overlay', fontsize=20)

  ax = plt.subplot(grid_spec[3])
  max_num_instances = max(len(color) for color in used_colors.values())
  # RGBA image as legend.
  legend = np.zeros((len(used_colors), max_num_instances, 4), dtype=np.uint8)
  class_names = []
  for i, semantic_id in enumerate(sorted(used_colors)):
    legend[i, :len(used_colors[semantic_id]), :3] = np.array(
        list(used_colors[semantic_id]))
    legend[i, :len(used_colors[semantic_id]), 3] = 255
    if semantic_id < dataset_info.num_classes:
      class_names.append(dataset_info.class_names[semantic_id])
    else:
      class_names.append('ignore')
  plt.imshow(legend, interpolation='nearest')
  ax.yaxis.tick_left()
  plt.yticks(range(len(legend)), class_names, fontsize=15)
  plt.xticks([], [])
  ax.tick_params(width=0.0, grid_linewidth=0.0)
  plt.grid('off')
  return plt
def run_cmd(command):
  try:
    print(command)
    call(command, shell=True)
  except KeyboardInterrupt:
    print("Process interrupted")
    sys.exit(1)
MODEL_NAME = 'resnet50_os32_panoptic_deeplab_cityscapes_crowd_trainfine_saved_model'

_MODELS = ('resnet50_os32_panoptic_deeplab_cityscapes_crowd_trainfine_saved_model',
           'resnet50_beta_os32_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'wide_resnet41_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_1_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_3_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_4.5_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_1_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_3_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_4.5_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'max_deeplab_s_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'max_deeplab_l_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model')
_DOWNLOAD_URL_PATTERN = 'https://storage.googleapis.com/gresearch/tf-deeplab/saved_model/%s.tar.gz'
_MODEL_NAME_TO_URL_AND_DATASET = {
    model: (_DOWNLOAD_URL_PATTERN % model, cityscapes_dataset_information())
    for model in _MODELS
}
MODEL_URL, DATASET_INFO = _MODEL_NAME_TO_URL_AND_DATASET[MODEL_NAME]

# Download the exported SavedModel archive, unpack it into a temp dir, and
# load it once at startup so every request reuses the same model.
model_dir = tempfile.mkdtemp()
download_path = os.path.join(model_dir, MODEL_NAME + '.gz')
urllib.request.urlretrieve(MODEL_URL, download_path)
run_cmd("tar -xzvf " + download_path + " -C " + model_dir)
LOADED_MODEL = tf.saved_model.load(os.path.join(model_dir, MODEL_NAME))
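Shelling out to `tar` is fine on the Linux Spaces runtime; a portable alternative, if you'd rather avoid `run_cmd`, is the standard-library `tarfile` module (a sketch, not what the app does):

# Portable alternative sketch using the standard library instead of `tar`.
import tarfile
with tarfile.open(download_path, 'r:gz') as archive:
  archive.extractall(model_dir)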
def inference(image):
  # The demo resizes every input to a fixed 512x512 before running the model.
  image = image.resize(size=(512, 512))
  im = np.array(image)
  output = LOADED_MODEL(tf.cast(im, tf.uint8))
  return vis_segmentation(im, output['panoptic_pred'][0], DATASET_INFO)
title = "Deeplab2"
description = "Demo for DeepLab2. To use it, simply upload your image, or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2106.09748'>DeepLab2: A TensorFlow Library for Deep Labeling</a> | <a href='https://github.com/google-research/deeplab2'>Github Repo</a></p>"

gr.Interface(
    inference,
    [gr.inputs.Image(type="pil", label="Input")],
    gr.outputs.Image(type="plot", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["city1.jpg"],
        ["city2.jpg"]
    ]).launch()
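Note that `gr.inputs` / `gr.outputs` and the `"plot"` output type belong to the legacy Gradio 2.x API used at the time of this commit. On Gradio 3+ the equivalent interface would look roughly like the sketch below (untested against this model):

# Rough Gradio 3+ equivalent (sketch only; the 2.x `gr.inputs`/`gr.outputs`
# namespaces were removed in later releases).
gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil", label="Input"),
    outputs=gr.Plot(label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[["city1.jpg"], ["city2.jpg"]],
).launch()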