Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,7 +17,7 @@ import random
|
|
| 17 |
import spaces
|
| 18 |
import gradio as gr
|
| 19 |
|
| 20 |
-
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth')
|
| 21 |
mobile_sam.eval()
|
| 22 |
mobile_predictor = SamPredictor(mobile_sam)
|
| 23 |
colors = [(255, 0, 0), (0, 255, 0)]
|
|
@@ -269,6 +269,7 @@ with block:
|
|
| 269 |
for p, l in sel_pix:
|
| 270 |
points.append(p)
|
| 271 |
labels.append(l)
|
|
|
|
| 272 |
mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
|
| 273 |
with torch.no_grad():
|
| 274 |
masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
|
|
|
|
| 17 |
import spaces
|
| 18 |
import gradio as gr
|
| 19 |
|
| 20 |
+
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth')
|
| 21 |
mobile_sam.eval()
|
| 22 |
mobile_predictor = SamPredictor(mobile_sam)
|
| 23 |
colors = [(255, 0, 0), (0, 255, 0)]
|
|
|
|
| 269 |
for p, l in sel_pix:
|
| 270 |
points.append(p)
|
| 271 |
labels.append(l)
|
| 272 |
+
mobile_predictor=mobile_predictor.to("cuda")
|
| 273 |
mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
|
| 274 |
with torch.no_grad():
|
| 275 |
masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
|