# modeling_patent.py
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from transformers import PreTrainedModel, PretrainedConfig


class PatentClassifierConfig(PretrainedConfig):
    model_type = "patent_classifier"  # must match model_type in config.json

    def __init__(self, model_name="Qwen/Qwen3-0.6B", hidden_dims=[512, 256],
                 output_dim=9, dropout_rate=0.1, max_length=256, **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        self.dropout_rate = dropout_rate
        self.max_length = max_length


class PatentClassifier(PreTrainedModel):
    config_class = PatentClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Load the base model
        if "qwen" in config.model_name.lower():
            self.base_llm_model = AutoModelForCausalLM.from_pretrained(
                config.model_name, trust_remote_code=True
            )
        else:
            self.base_llm_model = AutoModel.from_pretrained(config.model_name)

        # Freeze the pretrained model's parameters
        for param in self.base_llm_model.parameters():
            param.requires_grad = False

        # Build the MLP classification head
        self.hidden_size = self.base_llm_model.config.hidden_size
        layers = []
        input_dim = self.hidden_size
        for dim in config.hidden_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(config.dropout_rate))
            input_dim = dim
        layers.append(nn.Linear(input_dim, config.output_dim))
        self.classifier = nn.Sequential(*layers)

        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    def forward(self, input_ids, attention_mask):
        # The base model is frozen, so no gradients are needed for its forward pass
        with torch.no_grad():
            outputs = self.base_llm_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )

        # Mean-pool the last hidden state over non-padding tokens
        last_hidden_state = outputs.hidden_states[-1]
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        weighted_hidden = last_hidden_state * mask
        cls_embedding = weighted_hidden.sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

        return self.classifier(cls_embedding)

    def tokenize(self, texts, max_length=None):
        max_length = max_length or self.config.max_length
        return self.tokenizer(
            texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
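

# Minimal usage sketch (not part of the original module; the example texts and
# the assumption that 9 output classes map to patent categories are illustrative
# only). It shows how the config, tokenizer helper, and forward pass fit together.
if __name__ == "__main__":
    config = PatentClassifierConfig(model_name="Qwen/Qwen3-0.6B", output_dim=9)
    model = PatentClassifier(config)
    model.eval()

    texts = [
        "A method for classifying patent abstracts with a frozen language model",
        "Battery electrode coating apparatus and manufacturing process",
    ]
    batch = model.tokenize(texts)

    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])

    predictions = logits.argmax(dim=-1)
    print(predictions.tolist())  # one predicted class index per input text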