Sifal committed on
Commit 2bfa274 · verified · 1 Parent(s): 18e6b12

Delete bert_classification.py

Files changed (1)
  1. bert_classification.py +0 -135
bert_classification.py DELETED
@@ -1,135 +0,0 @@
- import logging
- from typing import Optional, Tuple, Union
-
- import torch
- import torch.nn as nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
- from transformers import BertPreTrainedModel
- from transformers.modeling_outputs import SequenceClassifierOutput
-
- from bert_layers_mosa import BertModel
-
- logger = logging.getLogger(__name__)
-
-
- class MosaicBertForSequenceClassification(BertPreTrainedModel):
-     """Bert Model transformer with a sequence classification/regression head.
-
-     This head is just a linear layer on top of the pooled output.
-     """
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.num_labels = config.num_labels
-         self.config = config
-         self.bert = BertModel(config, add_pooling_layer=True)
-         classifier_dropout = (
-             config.classifier_dropout
-             if config.classifier_dropout is not None
-             else config.hidden_dropout_prob
-         )
-         self.dropout = nn.Dropout(classifier_dropout)
-         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
-         # this resets the weights
-         self.post_init()
-
-     @classmethod
-     def from_pretrained(
-         cls, pretrained_checkpoint, state_dict=None, config=None, *inputs, **kwargs
-     ):
-         """Load from pre-trained."""
-         # this gets a fresh init model
-         model = cls(config, *inputs, **kwargs)
-
-         # thus we need to load the state_dict
-         state_dict = torch.load(pretrained_checkpoint)
-         # remove `model` prefix to avoid error
-         consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
-         missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-
-         if len(missing_keys) > 0:
-             logger.warning(
-                 f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
-             )
-
-             logger.warning(f"the number of which is equal to {len(missing_keys)}")
-
-         if len(unexpected_keys) > 0:
-             logger.warning(
-                 f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
-             )
-             logger.warning(f"the number of which is equal to {len(unexpected_keys)}")
-
-         return model
-
-     def forward(
-         self,
-         input_ids: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.Tensor] = None,
-         head_mask: Optional[torch.Tensor] = None,
-         inputs_embeds: Optional[torch.Tensor] = None,
-         labels: Optional[torch.Tensor] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
-
-         return_dict = (
-             return_dict if return_dict is not None else self.config.use_return_dict
-         )
-
-         outputs = self.bert(
-             input_ids,
-             attention_mask=attention_mask,
-             token_type_ids=token_type_ids,
-             position_ids=position_ids,
-             head_mask=head_mask,
-             inputs_embeds=inputs_embeds,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         pooled_output = outputs[1]
-
-         pooled_output = self.dropout(pooled_output)
-         logits = self.classifier(pooled_output)
-
-         loss = None
-         if labels is not None:
-             if self.config.problem_type is None:
-                 if self.num_labels == 1:
-                     self.config.problem_type = "regression"
-                 elif self.num_labels > 1 and (
-                     labels.dtype == torch.long or labels.dtype == torch.int
-                 ):
-                     self.config.problem_type = "single_label_classification"
-                 else:
-                     self.config.problem_type = "multi_label_classification"
-
-             if self.config.problem_type == "regression":
-                 loss_fct = MSELoss()
-                 if self.num_labels == 1:
-                     loss = loss_fct(logits.squeeze(), labels.squeeze())
-                 else:
-                     loss = loss_fct(logits, labels)
-             elif self.config.problem_type == "single_label_classification":
-                 loss_fct = CrossEntropyLoss()
-                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-             elif self.config.problem_type == "multi_label_classification":
-                 loss_fct = BCEWithLogitsLoss()
-                 loss = loss_fct(logits, labels)
-         if not return_dict:
-             output = (logits,) + outputs[2:]
-             return ((loss,) + output) if loss is not None else output
-
-         return SequenceClassifierOutput(
-             loss=loss,
-             logits=logits,
-             hidden_states=None,
-             attentions=None,
-         )
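For reference, a minimal sketch of how the removed class and its custom from_pretrained override were intended to be driven. This is a hedged illustration, not part of the repository: the checkpoint path "model.pt", the config path "config.json", and the use of the "bert-base-uncased" tokenizer are assumptions chosen for the example.

# Usage sketch for the deleted MosaicBertForSequenceClassification
# (paths below are hypothetical placeholders, not files from this repo).
import torch
from transformers import AutoTokenizer, BertConfig

from bert_classification import MosaicBertForSequenceClassification

# The custom from_pretrained expects an already-built config plus a raw
# torch checkpoint path; it strips a leading "model." prefix from the keys.
config = BertConfig.from_pretrained("config.json", num_labels=2)
model = MosaicBertForSequenceClassification.from_pretrained(
    "model.pt",  # hypothetical checkpoint produced by the pretraining run
    config=config,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("An example sentence to classify.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# With return_dict enabled (the default), logits come back on the output object.
predicted_class = outputs.logits.argmax(dim=-1).item()
print(predicted_class)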