import numpy as np
import cv2
import torch
from PIL import Image
from transformers import Pipeline, AutoImageProcessor
from torchvision.ops import nms as torch_nms
from typing import List, Dict, Any, Tuple

# Detection thresholds and expansion ratios live in dbnet_constants; the
# relative import covers package use, the absolute import covers running
# this file directly.
try:
    from .dbnet_constants import (
        BOX_EXPAND_RATIO,
        NMS_IOU_THRESHOLD,
        POLY_EXPAND_RATIO,
        SCORE_THRESHOLD,
        SHRINK_THRESHOLD,
    )
except ImportError:
    from dbnet_constants import (
        BOX_EXPAND_RATIO,
        NMS_IOU_THRESHOLD,
        POLY_EXPAND_RATIO,
        SCORE_THRESHOLD,
        SHRINK_THRESHOLD,
    )

# A box in [x, y, width, height] pixel coordinates.
BoxXYWH = List[float]


def scale_boxes_back_xywh(
    boxes_xywh: List[BoxXYWH],
    transform_info: Dict[str, Any],
) -> List[List[int]]:
    """Map boxes from model-input coordinates back to original-image pixels.

    Boxes that lie entirely inside the stamped region are kept as-is; all
    others are rescaled by the inverse of the preprocessing scale factor.
    Results are clamped to the original image bounds, and boxes that
    degenerate to zero width or height are dropped.
    """
    if not boxes_xywh:
        return []

    scale_factor = float(transform_info["scale_factor"])
    orig_w, orig_h = transform_info["original_size"]
    stamp_w, stamp_h = transform_info["stamp_size"]

    inv_scale = 1.0 / scale_factor
    mapped = []

    for box in boxes_xywh:
        x, y, w, h = (float(v) for v in box)
        if w <= 0.0 or h <= 0.0:
            continue

        x1, y1 = x, y
        x2, y2 = x + w, y + h

        inside_stamp = (
            x1 >= 0.0 and y1 >= 0.0 and
            x2 <= float(stamp_w) and y2 <= float(stamp_h)
        )

        if inside_stamp:
            x_orig, y_orig, w_orig, h_orig = x, y, w, h
        else:
            x_orig = x * inv_scale
            y_orig = y * inv_scale
            w_orig = w * inv_scale
            h_orig = h * inv_scale

        # Clamp to the original image and drop boxes that collapse.
        x_orig = max(0.0, min(x_orig, orig_w))
        y_orig = max(0.0, min(y_orig, orig_h))
        w_orig = max(0.0, min(w_orig, orig_w - x_orig))
        h_orig = max(0.0, min(h_orig, orig_h - y_orig))
        if w_orig <= 0.0 or h_orig <= 0.0:
            continue

        mapped.append([
            int(round(x_orig)),
            int(round(y_orig)),
            int(round(w_orig)),
            int(round(h_orig)),
        ])

    return mapped

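# A worked sketch with a hypothetical transform_info (field names as produced
# by the matching DBNetImageProcessor are assumed here): a 2000x1000 page
# scaled by 0.5 and stamped onto the model canvas as a 1000x500 region.
#
#   info = {"scale_factor": 0.5,
#           "original_size": (2000, 1000),
#           "stamp_size": (1000, 500)}
#   scale_boxes_back_xywh([[100.0, 80.0, 100.0, 50.0]], info)
#   # -> [[100, 80, 100, 50]]   (inside the stamp: kept as-is)
#   scale_boxes_back_xywh([[100.0, 480.0, 100.0, 50.0]], info)
#   # -> [[200, 960, 200, 40]]  (outside: rescaled by 2x, then clamped)

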
def xywh_to_xyxy(box: BoxXYWH) -> List[float]:
    """Convert [x, y, w, h] to corner form [x1, y1, x2, y2]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]


def xyxy_to_xywh(box: List[float]) -> BoxXYWH:
    """Convert corner form [x1, y1, x2, y2] back to [x, y, w, h]."""
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1, y2 - y1]

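# Round-trip sanity check:
#
#   xywh_to_xyxy([10, 20, 30, 40])   # -> [10, 20, 40, 60]
#   xyxy_to_xywh([10, 20, 40, 60])   # -> [10, 20, 30, 40]

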
def nms_xywh_with_scores(
    boxes: List[BoxXYWH],
    scores: List[float],
    iou_threshold: float,
    device="cpu",
) -> Tuple[List[BoxXYWH], List[float]]:
    """Run torchvision NMS on xywh boxes, returning survivors with scores."""
    if not boxes:
        return [], []

    boxes_xyxy = np.array([xywh_to_xyxy(b) for b in boxes], dtype=np.float32)
    scores_np = np.array(scores, dtype=np.float32)

    # Drop zero-area boxes before handing off to torchvision.
    widths = boxes_xyxy[:, 2] - boxes_xyxy[:, 0]
    heights = boxes_xyxy[:, 3] - boxes_xyxy[:, 1]
    areas = widths * heights
    keep_mask = areas > 0
    boxes_xyxy = boxes_xyxy[keep_mask]
    scores_np = scores_np[keep_mask]

    if len(boxes_xyxy) == 0:
        return [], []

    boxes_t = torch.from_numpy(boxes_xyxy).to(device)
    scores_t = torch.from_numpy(scores_np).to(device)

    keep_indices = torch_nms(boxes_t, scores_t, iou_threshold)
    keep_indices_np = keep_indices.cpu().numpy()

    kept_xyxy = boxes_xyxy[keep_indices_np]
    kept_scores = scores_np[keep_indices_np].tolist()
    kept_boxes = [xyxy_to_xywh(b) for b in kept_xyxy]

    return kept_boxes, kept_scores

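# Quick sanity check (an IoU threshold of 0.5 is assumed for illustration):
# these two boxes overlap with IoU ~0.68, so the lower-scoring one is
# suppressed.
#
#   nms_xywh_with_scores([[0, 0, 10, 10], [1, 1, 10, 10]], [0.9, 0.8], 0.5)
#   # -> ([[0.0, 0.0, 10.0, 10.0]], [~0.9])

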
def expand_box_xywh(
    x: float, y: float, w: float, h: float,
    img_w: float, img_h: float, ratio: float,
) -> List[int]:
    """Grow a box symmetrically by `ratio`, clamped to the image bounds.

    Returns the rounded box, or [] if the expanded box degenerates.
    """
    if ratio <= 0.0 or w <= 0.0 or h <= 0.0:
        return [int(round(x)), int(round(y)), int(round(w)), int(round(h))]

    dx = w * ratio
    dy = h * ratio

    new_x = max(0.0, x - dx / 2.0)
    new_y = max(0.0, y - dy / 2.0)
    new_w = min(img_w - new_x, w + dx)
    new_h = min(img_h - new_y, h + dy)

    if new_w <= 0.0 or new_h <= 0.0:
        return []

    return [
        int(round(new_x)), int(round(new_y)), int(round(new_w)), int(round(new_h))
    ]

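# Example: a 100x50 box expanded by 10% inside a 640x480 image grows by
# 10 px horizontally and 5 px vertically, centered on the original box.
#
#   expand_box_xywh(10, 10, 100, 50, 640, 480, 0.1)
#   # -> [5, 8, 110, 55]   (7.5 rounds to 8)

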
class DBNetPipeline(Pipeline):
    _load_image_processor = True

    def __init__(self, model, tokenizer=None, feature_extractor=None, image_processor=None, **kwargs):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            image_processor=image_processor,
            **kwargs,
        )
        # Fall back to loading the processor from the model's own repo when
        # none was passed in explicitly.
        if self.image_processor is None:
            processor_repo = getattr(self.model.config, "_name_or_path", None)
            if processor_repo:
                try:
                    self.image_processor = AutoImageProcessor.from_pretrained(
                        processor_repo,
                        trust_remote_code=True,
                    )
                except Exception as exc:
                    raise ValueError(
                        f"Failed to load image processor for repo '{processor_repo}'. "
                        "Pass an initialized `DBNetImageProcessor` when creating the pipeline."
                    ) from exc
        if self.image_processor is None:
            raise ValueError(
                "DBNetPipeline requires an image processor. "
                "Ensure `DBNetImageProcessor` is available and passed to the pipeline."
            )

    def _sanitize_parameters(self, **kwargs):
        # No call-time parameters are supported yet; detection behavior is
        # driven entirely by the module-level constants.
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, image, **kwargs):
        # Normalize numpy inputs to 8-bit RGB PIL images; PIL inputs pass
        # straight through to the image processor.
        if isinstance(image, np.ndarray):
            if image.dtype != np.uint8:
                # Treat float arrays in [0, 1] as normalized intensities.
                if image.max() <= 1.0:
                    image = (image * 255).astype(np.uint8)
                else:
                    image = image.astype(np.uint8)

            # Promote grayscale (HxW or HxWx1) to three channels.
            if image.ndim == 2:
                image = np.stack([image] * 3, axis=-1)
            elif image.shape[-1] == 1:
                image = np.repeat(image, 3, axis=-1)

            image = Image.fromarray(image)

        return self.image_processor(images=image, return_tensors="pt")

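    # Example (hedged): a float array in [0, 1] is promoted to 3-channel
    # uint8 before reaching the processor, so numpy and PIL inputs behave
    # identically:
    #
    #   arr = np.random.rand(480, 640).astype(np.float32)
    #   inputs = pipe.preprocess(arr)  # equivalent to a PIL RGB image
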
    def _forward(self, model_inputs):
        pixel_values = model_inputs["pixel_values"]
        with torch.no_grad():
            outputs = self.model(pixel_values)
        # Carry the transform metadata through so postprocess can map boxes
        # back to original-image coordinates.
        return {
            "logits": outputs,
            "transform_info": model_inputs["transform_info"],
            "original_size": model_inputs["original_size"],
        }

    def postprocess(self, model_outputs, **kwargs):
        preds = model_outputs["logits"]
        transform_info_list = model_outputs["transform_info"]
        original_size_list = model_outputs["original_size"]

        batch_results = []

        for i in range(preds.shape[0]):
            shrink_map = preds[i, 0].cpu().numpy()
            transform_info = transform_info_list[i]
            orig_w, orig_h = original_size_list[i]

            H, W = shrink_map.shape

            # Binarize the shrink map and extract candidate text contours.
            mask = (shrink_map > SHRINK_THRESHOLD).astype(np.uint8) * 255
            contours, _ = cv2.findContours(
                mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
            )

            boxes_xywh = []
            scores = []

            for cnt in contours:
                if len(cnt) < 3:
                    continue

                if POLY_EXPAND_RATIO > 0.0:
                    # Dilate the contour about its centroid to undo the
                    # shrinking that DBNet applies to text regions.
                    cnt_f = cnt.astype(np.float32)
                    pts = cnt_f.reshape(-1, 2)
                    cx, cy = pts.mean(axis=0)
                    pts_exp = (
                        (pts - np.array([cx, cy], dtype=np.float32))
                        * (1.0 + POLY_EXPAND_RATIO)
                        + np.array([cx, cy], dtype=np.float32)
                    )
                    cnt_exp = pts_exp.reshape(-1, 1, 2).astype(np.int32)
                    x, y, w, h = cv2.boundingRect(cnt_exp)
                else:
                    x, y, w, h = cv2.boundingRect(cnt)

                if w <= 0 or h <= 0:
                    continue

                # Score each box by the mean probability under it.
                y1 = max(0, y)
                x1 = max(0, x)
                y2 = min(H, y + h)
                x2 = min(W, x + w)
                region = shrink_map[y1:y2, x1:x2]
                if region.size == 0:
                    continue
                score = float(region.mean())

                boxes_xywh.append([float(x), float(y), float(w), float(h)])
                scores.append(score)

            # First NMS pass in model-input coordinates.
            device = preds.device
            boxes_nms, scores_nms = nms_xywh_with_scores(
                boxes_xywh, scores, NMS_IOU_THRESHOLD, device=device
            )

            # Map boxes back to original-image coordinates one at a time:
            # scale_boxes_back_xywh can drop a degenerate box, so mapping
            # per-box keeps the scores aligned with their boxes.
            boxes_orig_xywh = []
            scores_orig = []
            for box, score in zip(boxes_nms, scores_nms):
                mapped = scale_boxes_back_xywh([box], transform_info)
                if mapped:
                    boxes_orig_xywh.append(mapped[0])
                    scores_orig.append(score)

            # Second NMS pass, since the coordinate mapping can create new
            # overlaps.
            boxes_after_nms, scores_after_nms = nms_xywh_with_scores(
                boxes_orig_xywh, scores_orig, NMS_IOU_THRESHOLD, device=device
            )

            final_results = []

            for (x, y, w, h), score in zip(boxes_after_nms, scores_after_nms):
                if score < SCORE_THRESHOLD:
                    continue

                expanded = expand_box_xywh(
                    x, y, w, h, orig_w, orig_h, BOX_EXPAND_RATIO
                )
                if not expanded:
                    continue

                ex, ey, ew, eh = expanded

                final_results.append({
                    "score": score,
                    "label": "text",
                    "box": {
                        "xmin": int(ex),
                        "ymin": int(ey),
                        "xmax": int(ex + ew),
                        "ymax": int(ey + eh),
                    },
                })

            batch_results.append(final_results)

        return batch_results
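

# Minimal usage sketch. The repo id and the AutoModel loading call below are
# assumptions for illustration; adapt them to wherever the DBNet checkpoint
# and its custom code actually live.
if __name__ == "__main__":
    from transformers import AutoModel

    # Hypothetical repo id; replace with the real DBNet checkpoint.
    model = AutoModel.from_pretrained(
        "your-org/dbnet-text-detection", trust_remote_code=True
    )
    pipe = DBNetPipeline(model=model)

    detections = pipe(Image.open("page.png"))
    # Each image yields a list of dicts shaped like:
    #   {"score": 0.91, "label": "text",
    #    "box": {"xmin": 12, "ymin": 30, "xmax": 220, "ymax": 58}}
    print(detections)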