from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2DecoderLayer,
    Qwen2ForCausalLM,
    Qwen2Model,
)
from transformers.configuration_utils import PretrainedConfig
import torch
import torch.nn as nn


class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
    """Qwen2 decoder layer with a per-layer residual bias ("Iron Wall")."""

    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        # Persistent buffer so the bias is saved and loaded with the checkpoint.
        self.register_buffer("resid_bias", torch.zeros(config.hidden_size), persistent=True)

    def forward(self, hidden_states, *args, **kwargs):
        outputs = super().forward(hidden_states, *args, **kwargs)
        if hasattr(self, "resid_bias") and self.resid_bias is not None:
            # The base layer may return either a tuple or a bare hidden-state tensor.
            hidden = outputs[0] if isinstance(outputs, tuple) else outputs
            bias = self.resid_bias.to(device=hidden.device, dtype=hidden.dtype)
            # Broadcasting: the Iron Wall bias is applied to ALL token positions.
            if bias.norm() > 0:
                hidden = hidden + bias.view(1, 1, -1)
            if isinstance(outputs, tuple):
                outputs = (hidden,) + outputs[1:]
            else:
                outputs = hidden
        return outputs


class CustomQwen2Model(Qwen2Model):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Replace every stock decoder layer with the custom one, then re-run weight init.
        self.layers = nn.ModuleList(
            [CustomQwen2DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.post_init()


class CustomQwen2ForCausalLM(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = CustomQwen2Model(config)
        self.post_init()
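

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of building the customized model from a stock Qwen2 config
# and running a forward pass. The checkpoint name "Qwen/Qwen2-0.5B" and the
# dummy input shape are assumptions for illustration; any Qwen2 config works
# the same way.
if __name__ == "__main__":
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B")  # assumed checkpoint name
    model = CustomQwen2ForCausalLM(config)  # randomly initialized; resid_bias starts at zero

    # To reuse pretrained weights instead, from_pretrained also works on the subclass:
    #   model = CustomQwen2ForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
    # The resid_bias buffers are then reported as missing keys and stay at zero.

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
    print(logits.shape)  # (1, 8, vocab_size)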