import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


class DBNetConfig(PretrainedConfig):
    model_type = "dbnet"

    def __init__(
        self,
        in_channels=3,
        # Backbone config (ResNet)
        backbone_layers=[2, 2, 2, 2],  # ResNet-18: [2,2,2,2]; [3,4,6,3] gives ResNet-34 (these are BasicBlocks, not Bottlenecks)
        backbone_base_channels=64,
        # Neck config (FPN)
        neck_lateral_channels=256,
        neck_out_channels=64,
        # Head config
        head_in_channels=256,  # neck_out_channels * number of FPN levels (64 * 4)
        k=50,
        **kwargs,
    ):
        self.in_channels = in_channels
        self.backbone_layers = backbone_layers
        self.backbone_base_channels = backbone_base_channels
        self.neck_lateral_channels = neck_lateral_channels
        self.neck_out_channels = neck_out_channels
        self.head_in_channels = head_in_channels
        self.k = k
        super().__init__(**kwargs)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        if stride != 1 or inplanes != planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.downsample = None
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.relu(out + identity)
        return out


class ResNetBackbone(nn.Module):
    """ResNet-style backbone built from BasicBlocks with configurable depth."""

    def __init__(self, in_channels=3, base_channels=64, layers=[2, 2, 2, 2]):
        super().__init__()
        self.inplanes = base_channels
        self.conv1 = nn.Conv2d(
            in_channels, base_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(base_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # C2, C3, C4, C5 stages at strides 4, 8, 16, 32
        self.layer1 = self._make_layer(base_channels, blocks=layers[0], stride=1)
        self.layer2 = self._make_layer(base_channels * 2, blocks=layers[1], stride=2)
        self.layer3 = self._make_layer(base_channels * 4, blocks=layers[2], stride=2)
        self.layer4 = self._make_layer(base_channels * 8, blocks=layers[3], stride=2)
        self.out_channels = [
            base_channels,
            base_channels * 2,
            base_channels * 4,
            base_channels * 8,
        ]

    def _make_layer(self, planes, blocks, stride):
        layers = [BasicBlock(self.inplanes, planes, stride=stride)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        c2 = self.layer1(x)   # 1/4
        c3 = self.layer2(c2)  # 1/8
        c4 = self.layer3(c3)  # 1/16
        c5 = self.layer4(c4)  # 1/32
        return [c2, c3, c4, c5]


class Conv1x1(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1, bias=False)

    def forward(self, x):
        return self.conv(x)


class Conv3x3(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)

    def forward(self, x):
        return self.conv(x)


class FPNC(nn.Module):
    """FPN neck that fuses C2-C5 top-down, then concatenates all levels."""

    def __init__(
        self,
        in_channels_list,
        lateral_channels=256,
        out_channels=64,
        conv_after_concat=False,
    ):
        super().__init__()
        self.num_outs = len(in_channels_list)
        self.conv_after_concat = conv_after_concat
        self.lateral_convs = nn.ModuleList(
            [Conv1x1(c, lateral_channels) for c in in_channels_list]
        )
        self.smooth_convs = nn.ModuleList(
            [Conv3x3(lateral_channels, out_channels) for _ in range(self.num_outs)]
        )
        if conv_after_concat:
            self.out_conv = nn.Conv2d(
                out_channels * self.num_outs,
                out_channels * self.num_outs,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            )

    def forward(self, inputs):
        # inputs: [c2, c3, c4, c5]
        laterals = [l_conv(x) for x, l_conv in zip(inputs, self.lateral_convs)]
        # Top-down fusion: upsample each level and add it into the next-finer one.
        for i in range(self.num_outs - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=prev_shape, mode="nearest"
            )
        outs = [s_conv(lat) for s_conv, lat in zip(self.smooth_convs, laterals)]
        # Resize every level to the finest (1/4) resolution and concatenate.
        size = outs[0].shape[2:]
        outs = [F.interpolate(o, size=size, mode="nearest") for o in outs]
        out = torch.cat(outs, dim=1)
        if self.conv_after_concat:
            out = self.out_conv(out)
        return out


class DBHead(nn.Module):
    """DB head: predicts shrink and threshold maps, then combines them via
    differentiable binarization."""

    def __init__(self, in_channels, k=50):
        super().__init__()
        self.k = k
        hidden = in_channels // 4

        def up_branch():
            # 3x3 conv plus two stride-2 transposed convs: 1/4 scale -> full resolution.
            return nn.Sequential(
                nn.Conv2d(in_channels, hidden, 3, padding=1, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(hidden, hidden, 2, stride=2, bias=True),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(hidden, 1, 2, stride=2, bias=True),
                nn.Sigmoid(),
            )

        self.binarize = up_branch()
        self.threshold = up_branch()

    def diff_bin(self, p, t):
        # Differentiable binarization: B = 1 / (1 + exp(-k * (P - T))) = sigmoid(k * (P - T)).
        return torch.sigmoid(self.k * (p - t))

    def forward(self, x):
        p = self.binarize(x)     # [B, 1, H, W] shrink (probability) map
        t = self.threshold(x)    # [B, 1, H, W] threshold map
        b = self.diff_bin(p, t)  # [B, 1, H, W] approximate binary map
        return torch.cat([p, t, b], dim=1)


class DBNetForTextDetection(PreTrainedModel):
    config_class = DBNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = ResNetBackbone(
            in_channels=config.in_channels,
            base_channels=config.backbone_base_channels,
            layers=config.backbone_layers,
        )
        self.neck = FPNC(
            in_channels_list=self.backbone.out_channels,
            lateral_channels=config.neck_lateral_channels,
            out_channels=config.neck_out_channels,
            conv_after_concat=False,  # Match ocr_worker configuration
        )
        # Head input channels follow from the neck: one out_channels-wide map per FPN level.
        neck_output = config.neck_out_channels * len(self.backbone.out_channels)
        self.det_head = DBHead(in_channels=neck_output, k=config.k)

    def forward(self, pixel_values):
        feats = self.backbone(pixel_values)
        fpn_out = self.neck(feats)
        return self.det_head(fpn_out)
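

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal smoke test assuming the default ResNet-18-style config. Input
# height and width divisible by 32 keep the head's upsampled output aligned
# with the input resolution.
if __name__ == "__main__":
    config = DBNetConfig()
    model = DBNetForTextDetection(config)
    model.eval()
    with torch.no_grad():
        maps = model(torch.randn(1, 3, 640, 640))
    # Channel 0: shrink map, channel 1: threshold map, channel 2: approximate binary map.
    print(maps.shape)  # torch.Size([1, 3, 640, 640])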