import numpy as np
import cv2
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union, Optional

from PIL import Image, ImageOps
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import to_numpy_array

try:
    from .dbnet_constants import IMAGE_TARGET_SIZE, STAMP_HEIGHT, STAMP_WIDTH
except ImportError:  # pragma: no cover - fallback for local execution
    from dbnet_constants import IMAGE_TARGET_SIZE, STAMP_HEIGHT, STAMP_WIDTH  # type: ignore


class DBNetImageProcessor(BaseImageProcessor):
    """Preprocessor that turns images into fixed-size CHW float32 inputs for DBNet.

    Pipeline per image:
      1. load (path / PIL image / ndarray) and force 3-channel RGB uint8,
      2. aspect-preserving resize so the image fits inside ``size``,
      3. zero-pad to ``pad_size`` (image anchored at the top-left corner),
      4. re-stamp the original top-left ``stamp_size`` corner at native
         resolution over the padded canvas,
      5. optional rescale, optional mean/std normalization, HWC -> CHW.

    The default ``image_mean``/``image_std`` are the ImageNet statistics in
    0-255 pixel scale, consistent with the default ``do_rescale=False``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        do_rescale: bool = False,
        rescale_factor: Optional[float] = 1.0,
        do_normalize: bool = True,
        image_mean: Optional[Tuple[float, float, float]] = (123.675, 116.28, 103.53),
        image_std: Optional[Tuple[float, float, float]] = (58.395, 57.12, 57.375),
        do_pad: bool = True,
        pad_size: Optional[Dict[str, int]] = None,
        stamp_size: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Configure the processor.

        Args:
            do_resize: Whether to resize so the image fits inside ``size``.
            size: ``{"height": H, "width": W}`` resize target; defaults to a
                square of ``IMAGE_TARGET_SIZE``.
            do_rescale: Whether to multiply pixel values by ``rescale_factor``
                before normalization.
            rescale_factor: Multiplier applied when ``do_rescale`` is True
                (``None`` falls back to 1.0).
            do_normalize: Whether to apply ``(x - image_mean) / image_std``.
            image_mean: Per-channel mean; ``None`` falls back to the 0-1-scale
                ImageNet means.
            image_std: Per-channel std; ``None`` falls back to the 0-1-scale
                ImageNet stds.
            do_pad: Stored for API compatibility.
            pad_size: ``{"height": H, "width": W}`` canvas size; defaults to a
                square of ``IMAGE_TARGET_SIZE``.
            stamp_size: ``{"width": W, "height": H}`` of the native-resolution
                corner copied over the canvas.
        """
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.size = size or {"height": IMAGE_TARGET_SIZE, "width": IMAGE_TARGET_SIZE}
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor if rescale_factor is not None else 1.0
        self.do_normalize = do_normalize
        self.image_mean = np.asarray(image_mean or (0.485, 0.456, 0.406), dtype=np.float32)
        self.image_std = np.asarray(image_std or (0.229, 0.224, 0.225), dtype=np.float32)
        # NOTE(review): do_pad is stored but padding is currently unconditional
        # in _resize_and_pad_with_stamp -- confirm whether it should be honored.
        self.do_pad = do_pad
        self.pad_size = pad_size or {"height": IMAGE_TARGET_SIZE, "width": IMAGE_TARGET_SIZE}
        self.stamp_size = stamp_size or {"width": STAMP_WIDTH, "height": STAMP_HEIGHT}

    def __call__(
        self,
        images: Union[Image.Image, np.ndarray, str, Path, List[Union[Image.Image, np.ndarray, str, Path]]],
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list of images.

        Args:
            images: A single image (PIL image, HWC ndarray, or path) or a
                list of such images.
            return_tensors: ``"pt"`` stacks results into a torch tensor,
                ``"np"`` into a numpy array; anything else returns a plain
                list of per-image CHW float32 arrays.

        Returns:
            BatchFeature with ``pixel_values`` plus per-image
            ``transform_info`` dicts and ``original_size`` (width, height)
            tuples for mapping detections back to source coordinates.

        Raises:
            ImportError: If ``return_tensors="pt"`` and torch is unavailable.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        pixel_values: List[np.ndarray] = []
        transform_infos: List[Dict[str, Any]] = []
        original_sizes: List[Tuple[int, int]] = []
        for image in images:
            img_rgb, orig_size = self._load_image(image)
            processed, transform_info = self._resize_and_pad_with_stamp(img_rgb)
            processed = processed.astype(np.float32)
            if self.do_rescale:
                processed *= self.rescale_factor
            if self.do_normalize:
                processed = (processed - self.image_mean) / self.image_std
            processed = np.transpose(processed, (2, 0, 1))  # HWC -> CHW
            pixel_values.append(processed)
            transform_infos.append(transform_info)
            original_sizes.append(orig_size)

        data: Dict[str, Any] = {
            "transform_info": transform_infos,
            "original_size": original_sizes,
        }
        if return_tensors == "pt":
            try:
                import torch
            except ImportError as exc:
                raise ImportError("PyTorch is required when return_tensors='pt'.") from exc
            data["pixel_values"] = torch.from_numpy(np.stack(pixel_values, axis=0))
        elif return_tensors == "np":
            data["pixel_values"] = np.stack(pixel_values, axis=0)
        else:
            data["pixel_values"] = pixel_values
        # tensor_type=None: pixel_values were already converted above, and the
        # metadata lists must remain plain Python objects.
        return BatchFeature(data=data, tensor_type=None)

    def _load_image(self, image: Union[Image.Image, np.ndarray, str, Path]) -> Tuple[np.ndarray, Tuple[int, int]]:
        """Load ``image`` as an RGB uint8 HWC array.

        Returns:
            ``(array, (width, height))``; width/height reflect EXIF
            orientation for PIL/path inputs.
        """
        if isinstance(image, (str, Path)):
            # Context manager releases the file handle (Image.open is lazy);
            # exif_transpose/convert force the pixel data to be read first.
            with Image.open(image) as opened:
                img = ImageOps.exif_transpose(opened).convert("RGB")
            return np.array(img), (img.width, img.height)
        if isinstance(image, Image.Image):
            img = ImageOps.exif_transpose(image).convert("RGB")
            return np.array(img), (img.width, img.height)

        # BUGFIX: the module-level transformers.image_utils.to_numpy_array
        # takes a single argument; the previous `rescale=False` keyword raised
        # TypeError (that parameter belongs to the legacy mixin method).
        arr = to_numpy_array(image)
        if arr.ndim == 2:  # grayscale HW -> HWC
            arr = np.stack([arr] * 3, axis=-1)
        if arr.shape[-1] == 1:  # single-channel HW1 -> HWC
            arr = np.repeat(arr, 3, axis=-1)
        if arr.shape[-1] == 4:  # drop alpha channel
            arr = arr[..., :3]
        # Clip before the uint8 cast so out-of-range values saturate instead
        # of wrapping. NOTE(review): float arrays in [0, 1] still truncate to
        # 0/1 here -- callers are assumed to pass 0-255 data; confirm.
        arr = np.clip(arr, 0, 255).astype(np.uint8)
        return arr, (arr.shape[1], arr.shape[0])

    def _resize_and_pad_with_stamp(self, img_rgb: np.ndarray) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Resize to fit ``size``, pad to ``pad_size``, and stamp the corner.

        Args:
            img_rgb: RGB uint8 HWC array.

        Returns:
            ``(padded, transform_info)`` where ``padded`` is an HWC uint8
            canvas of ``pad_size`` and ``transform_info`` records the scale
            and sizes needed to invert the transform.
        """
        target_h = self.size["height"]
        target_w = self.size["width"]
        orig_h, orig_w = img_rgb.shape[:2]

        if self.do_resize:
            # BUGFIX: scale by the limiting side so BOTH dimensions fit inside
            # the target. The old width-vs-height branch is equivalent for the
            # default square target but overflowed the canvas (broadcast error
            # on the paste below) for non-square size/pad_size combinations.
            scale_factor = min(target_w / float(orig_w), target_h / float(orig_h))
            new_w = max(1, int(round(orig_w * scale_factor)))
            new_h = max(1, int(round(orig_h * scale_factor)))
            resized = cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
        else:
            scale_factor = 1.0
            resized = img_rgb.copy()
            new_h, new_w = orig_h, orig_w

        pad_h = self.pad_size["height"]
        pad_w = self.pad_size["width"]
        padded = np.zeros((pad_h, pad_w, 3), dtype=np.uint8)
        # Clamp the paste region so an unresized (do_resize=False) image larger
        # than the canvas cannot raise a broadcast error.
        copy_h = min(new_h, pad_h)
        copy_w = min(new_w, pad_w)
        padded[:copy_h, :copy_w, :] = resized[:copy_h, :copy_w, :]

        # Re-stamp the original-resolution top-left corner over the
        # downscaled image (clamped to both the source and the canvas).
        stamp_w = min(self.stamp_size["width"], orig_w, pad_w)
        stamp_h = min(self.stamp_size["height"], orig_h, pad_h)
        if stamp_w > 0 and stamp_h > 0:
            padded[0:stamp_h, 0:stamp_w, :] = img_rgb[0:stamp_h, 0:stamp_w, :]

        transform_info = {
            "scale_factor": float(scale_factor),
            "original_size": (orig_w, orig_h),
            "scaled_size": (new_w, new_h),
            "padded_size": (pad_w, pad_h),
            "stamp_size": (stamp_w, stamp_h),
        }
        return padded, transform_info