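"""DBNet text detection pipeline for the Hugging Face `transformers` library.

Post-processing: threshold the model's shrink map, extract contours, score
and optionally dilate each region, run NMS, map boxes back to the original
image, and emit `ObjectDetectionPipeline`-style results.
"""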
import numpy as np
import cv2
import torch
from PIL import Image
from transformers import Pipeline, AutoImageProcessor
from torchvision.ops import nms as torch_nms
from typing import List, Dict, Any, Tuple
try:
from .dbnet_constants import (
BOX_EXPAND_RATIO,
NMS_IOU_THRESHOLD,
POLY_EXPAND_RATIO,
SCORE_THRESHOLD,
SHRINK_THRESHOLD,
)
except ImportError: # pragma: no cover - fallback when running as a script
from dbnet_constants import ( # type: ignore
BOX_EXPAND_RATIO,
NMS_IOU_THRESHOLD,
POLY_EXPAND_RATIO,
SCORE_THRESHOLD,
SHRINK_THRESHOLD,
)
# A bounding box as [x, y, w, h] in pixels.
BoxXYWH = List[float]
def scale_boxes_back_xywh(
    boxes_xywh: List[BoxXYWH],
    transform_info: Dict[str, Any],
) -> List[List[int]]:
    """Map xywh boxes from the processed (resized/padded) space back to the
    original image.

    Boxes lying entirely inside the stamp region are kept 1:1; all others
    are divided by `scale_factor`. Results are clipped to the original image
    bounds and rounded to ints; degenerate boxes are dropped.
    """
if not boxes_xywh:
return []
scale_factor = float(transform_info["scale_factor"])
orig_w, orig_h = transform_info["original_size"]
stamp_w, stamp_h = transform_info["stamp_size"]
inv_scale = 1.0 / scale_factor
mapped = []
for box in boxes_xywh:
x, y, w, h = box
x, y, w, h = float(x), float(y), float(w), float(h)
if w <= 0.0 or h <= 0.0:
continue
x1, y1 = x, y
x2, y2 = x + w, y + h
inside_stamp = (
x1 >= 0.0 and y1 >= 0.0 and
x2 <= float(stamp_w) and y2 <= float(stamp_h)
)
if inside_stamp:
x_orig, y_orig, w_orig, h_orig = x, y, w, h
else:
x_orig = x * inv_scale
y_orig = y * inv_scale
w_orig = w * inv_scale
h_orig = h * inv_scale
x_orig = max(0.0, min(x_orig, orig_w))
y_orig = max(0.0, min(y_orig, orig_h))
w_orig = max(0.0, min(w_orig, orig_w - x_orig))
h_orig = max(0.0, min(h_orig, orig_h - y_orig))
if w_orig <= 0.0 or h_orig <= 0.0:
continue
mapped.append([
int(round(x_orig)),
int(round(y_orig)),
int(round(w_orig)),
int(round(h_orig)),
])
return mapped
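# Example (hypothetical values): with transform_info
# {"scale_factor": 0.5, "original_size": (200, 100), "stamp_size": (50, 50)},
# the box [60, 20, 40, 30] extends past the 50x50 stamp, so it is rescaled
# by 1 / 0.5 to [120, 40, 80, 60] in original-image coordinates.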
def xywh_to_xyxy(box: BoxXYWH) -> List[float]:
    """Convert [x, y, w, h] to [x1, y1, x2, y2]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]
def xyxy_to_xywh(box: List[float]) -> BoxXYWH:
    """Convert [x1, y1, x2, y2] to [x, y, w, h]."""
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1, y2 - y1]
def nms_xywh_with_scores(
    boxes: List[BoxXYWH],
    scores: List[float],
    iou_threshold: float,
    device="cpu",
) -> Tuple[List[BoxXYWH], List[float]]:
    """Run torchvision NMS on xywh boxes, dropping zero-area boxes first.

    Returns the kept boxes (xywh) and their scores, in NMS order.
    """
if not boxes:
return [], []
boxes_xyxy = np.array([xywh_to_xyxy(b) for b in boxes], dtype=np.float32)
scores_np = np.array(scores, dtype=np.float32)
widths = boxes_xyxy[:, 2] - boxes_xyxy[:, 0]
heights = boxes_xyxy[:, 3] - boxes_xyxy[:, 1]
areas = widths * heights
keep_mask = areas > 0
boxes_xyxy = boxes_xyxy[keep_mask]
scores_np = scores_np[keep_mask]
if len(boxes_xyxy) == 0:
return [], []
boxes_t = torch.from_numpy(boxes_xyxy).to(device)
scores_t = torch.from_numpy(scores_np).to(device)
keep_indices = torch_nms(boxes_t, scores_t, iou_threshold)
keep_indices_np = keep_indices.cpu().numpy()
kept_xyxy = boxes_xyxy[keep_indices_np]
kept_scores = scores_np[keep_indices_np].tolist()
kept_boxes = [xyxy_to_xywh(b) for b in kept_xyxy]
return kept_boxes, kept_scores
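# Example (hypothetical values): with IoU threshold 0.5, two heavily
# overlapping boxes collapse to the higher-scoring one:
#   nms_xywh_with_scores([[0, 0, 10, 10], [1, 1, 10, 10]], [0.9, 0.8], 0.5)
# keeps only the first box (IoU ~ 0.68 > 0.5).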
def expand_box_xywh(
    x: float, y: float, w: float, h: float,
    img_w: float, img_h: float, ratio: float,
) -> List[int]:
    """Grow a box by `ratio` of its width/height (half per side), clipped to
    the image. Returns [] if the clipped box collapses to zero size.
    """
if ratio <= 0.0 or w <= 0.0 or h <= 0.0:
return [int(round(x)), int(round(y)), int(round(w)), int(round(h))]
dx = w * ratio
dy = h * ratio
new_x = max(0.0, x - dx / 2.0)
new_y = max(0.0, y - dy / 2.0)
new_w = min(img_w - new_x, w + dx)
new_h = min(img_h - new_y, h + dy)
if new_w <= 0.0 or new_h <= 0.0:
return []
return [
int(round(new_x)), int(round(new_y)), int(round(new_w)), int(round(new_h))
]
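# Example (hypothetical values): expand_box_xywh(10, 10, 100, 50, 200, 100, 0.1)
# adds 10 px of width and 5 px of height (half per side) and returns
# [5, 8, 110, 55] (7.5 rounds to 8 under Python's round-half-to-even).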
# ============================================================
# PIPELINE CLASS
# ============================================================
class DBNetPipeline(Pipeline):
    """Text-detection pipeline wrapping a DBNet model.

    Results follow the `ObjectDetectionPipeline` format: a list of
    `{"score", "label", "box"}` dicts per image.
    """

    _load_image_processor = True
def __init__(self, model, tokenizer=None, feature_extractor=None, image_processor=None, **kwargs):
super().__init__(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
**kwargs,
)
if self.image_processor is None:
processor_repo = getattr(self.model.config, "_name_or_path", None)
if processor_repo:
try:
self.image_processor = AutoImageProcessor.from_pretrained(
processor_repo,
trust_remote_code=True,
)
except Exception as exc:
raise ValueError(
f"Failed to load image processor for repo '{processor_repo}'. "
"Pass an initialized `DBNetImageProcessor` when creating the pipeline."
) from exc
if self.image_processor is None:
raise ValueError(
"DBNetPipeline requires an image processor. "
"Ensure `DBNetImageProcessor` is available and passed to the pipeline."
)
    def _sanitize_parameters(self, **kwargs):
        # No runtime parameters are supported yet; detection behavior is
        # controlled by the module-level constants imported above.
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs
def preprocess(self, image, **kwargs):
# Handle different input types
# If it's a numpy array, convert to PIL Image
if isinstance(image, np.ndarray):
# Ensure it's in the right format (H, W, C) and uint8
            if image.dtype != np.uint8:
                # Heuristic: values in [0, 1] are assumed to be normalized.
                if image.max() <= 1.0:
                    image = (image * 255).astype(np.uint8)
                else:
                    image = image.astype(np.uint8)
# Handle grayscale
if image.ndim == 2:
image = np.stack([image] * 3, axis=-1)
elif image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
# Convert to PIL
image = Image.fromarray(image)
return self.image_processor(images=image, return_tensors="pt")
def _forward(self, model_inputs):
pixel_values = model_inputs["pixel_values"]
with torch.no_grad():
outputs = self.model(pixel_values)
return {
"logits": outputs,
"transform_info": model_inputs["transform_info"],
"original_size": model_inputs["original_size"]
}
def postprocess(self, model_outputs, **kwargs):
preds = model_outputs["logits"] # [B, 3, H, W]
transform_info_list = model_outputs["transform_info"]
original_size_list = model_outputs["original_size"]
batch_results = []
# Iterate over batch
for i in range(preds.shape[0]):
shrink_map = preds[i, 0].cpu().numpy()
transform_info = transform_info_list[i]
orig_w, orig_h = original_size_list[i]
H, W = shrink_map.shape
            # Binarize the shrink map and extract candidate text regions.
            mask = (shrink_map > SHRINK_THRESHOLD).astype(np.uint8) * 255
contours, _ = cv2.findContours(
mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
)
boxes_xywh = []
scores = []
for cnt in contours:
if len(cnt) < 3:
continue
                # Optionally dilate each contour about its centroid to
                # approximately undo the shrinking used when training DBNet.
                if POLY_EXPAND_RATIO > 0.0:
cnt_f = cnt.astype(np.float32)
pts = cnt_f.reshape(-1, 2)
cx, cy = pts.mean(axis=0)
pts_exp = (
(pts - np.array([cx, cy], dtype=np.float32))
* (1.0 + POLY_EXPAND_RATIO)
+ np.array([cx, cy], dtype=np.float32)
)
cnt_exp = pts_exp.reshape(-1, 1, 2).astype(np.int32)
x, y, w, h = cv2.boundingRect(cnt_exp)
else:
x, y, w, h = cv2.boundingRect(cnt)
if w <= 0 or h <= 0:
continue
                # Score the region as the mean shrink-map probability inside
                # its bounding box, clipped to the map extent.
                y1 = max(0, y)
x1 = max(0, x)
y2 = min(H, y + h)
x2 = min(W, x + w)
region = shrink_map[y1:y2, x1:x2]
if region.size == 0:
continue
score = float(region.mean())
boxes_xywh.append([float(x), float(y), float(w), float(h)])
scores.append(score)
# NMS in padded space
device = preds.device
boxes_nms, scores_nms = nms_xywh_with_scores(
boxes_xywh, scores, NMS_IOU_THRESHOLD, device=device
)
# Map back to original coords
boxes_orig_xywh = scale_boxes_back_xywh(boxes_nms, transform_info)
# NMS again in original coords
boxes_after_nms, scores_after_nms = nms_xywh_with_scores(
boxes_orig_xywh, scores_nms, NMS_IOU_THRESHOLD, device=device
)
final_results = []
for (x, y, w, h), score in zip(boxes_after_nms, scores_after_nms):
if score < SCORE_THRESHOLD:
continue
expanded = expand_box_xywh(
x, y, w, h, orig_w, orig_h, BOX_EXPAND_RATIO
)
if not expanded:
continue
ex, ey, ew, eh = expanded
# Format for ObjectDetectionPipeline
final_results.append({
"score": score,
"label": "text",
"box": {
"xmin": int(ex),
"ymin": int(ey),
"xmax": int(ex + ew),
"ymax": int(ey + eh)
}
})
batch_results.append(final_results)
return batch_results
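# ============================================================
# USAGE SKETCH
# ============================================================
# A minimal usage sketch; the repo id and image file below are hypothetical
# placeholders, not shipped with this module:
#
#     from PIL import Image
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained(
#         "some-user/dbnet-checkpoint",  # hypothetical repo id
#         trust_remote_code=True,
#     )
#     detector = DBNetPipeline(model=model)
#     results = detector(Image.open("page.png"))  # hypothetical image
#     # Each image yields dicts like:
#     # {"score": 0.91, "label": "text",
#     #  "box": {"xmin": 10, "ymin": 20, "xmax": 180, "ymax": 44}}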