Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,9 +15,8 @@ import torch
|
|
| 15 |
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
| 16 |
import random
|
| 17 |
import gradio as gr
|
| 18 |
-
import spaces
|
| 19 |
|
| 20 |
-
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth')
|
| 21 |
mobile_sam.eval()
|
| 22 |
mobile_predictor = SamPredictor(mobile_sam)
|
| 23 |
colors = [(255, 0, 0), (0, 255, 0)]
|
|
@@ -74,7 +73,6 @@ def resize_image(input_image, resolution):
|
|
| 74 |
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
| 75 |
return img
|
| 76 |
|
| 77 |
-
@spaces.GPU
|
| 78 |
def process(input_image,
|
| 79 |
original_image,
|
| 80 |
original_mask,
|
|
@@ -275,7 +273,7 @@ with block:
|
|
| 275 |
for p, l in sel_pix:
|
| 276 |
points.append(p)
|
| 277 |
labels.append(l)
|
| 278 |
-
mobile_predictor=mobile_predictor
|
| 279 |
mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
|
| 280 |
with torch.no_grad():
|
| 281 |
masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
|
|
|
|
| 15 |
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
| 16 |
import random
|
| 17 |
import gradio as gr
|
|
|
|
| 18 |
|
| 19 |
+
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to("cuda")
|
| 20 |
mobile_sam.eval()
|
| 21 |
mobile_predictor = SamPredictor(mobile_sam)
|
| 22 |
colors = [(255, 0, 0), (0, 255, 0)]
|
|
|
|
| 73 |
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
| 74 |
return img
|
| 75 |
|
|
|
|
| 76 |
def process(input_image,
|
| 77 |
original_image,
|
| 78 |
original_mask,
|
|
|
|
| 273 |
for p, l in sel_pix:
|
| 274 |
points.append(p)
|
| 275 |
labels.append(l)
|
| 276 |
+
mobile_predictor=mobile_predictor
|
| 277 |
mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
|
| 278 |
with torch.no_grad():
|
| 279 |
masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
|