oswaldoludwig committed
Commit 6a08381 · verified · 1 Parent(s): b4b1908

Upload 2 files

complete_example_use_selective_fine_tuning.py ADDED
@@ -0,0 +1,70 @@
+ # This script exemplifies a selective fine-tuning method based on the condition number, using freely available data and an LLM
+ # Author: Oswaldo Ludwig (now with AI support)
+ # Date: 03/07/2025
+ # In case of publication using this script or ideas in this script, cite:
+ # Ludwig, Oswaldo. "The Condition Number as a Scale-Invariant Proxy for Information Encoding in Neural Units." arXiv preprint arXiv:2506.16289 (2025).
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from datasets import load_dataset
+ from selective_fine_tuning import SelectiveFineTuningOptimizer
+
+ # Dataset using AG News
+ class RealTextDataset(Dataset):
+     def __init__(self, tokenizer, split='train', max_samples=200, seq_len=64):
+         dataset = load_dataset("ag_news", split=split)
+         self.samples = dataset.select(range(max_samples))
+         self.tokenizer = tokenizer
+         self.seq_len = seq_len
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         text = self.samples[idx]['text']
+         encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.seq_len, return_tensors='pt')
+         input_ids = encoding['input_ids'].squeeze(0)
+         return input_ids, input_ids.clone()
+
+ # Training loop
+ def train():
+     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     if tokenizer.pad_token is None:  # some Llama-style tokenizers define no pad token
+         tokenizer.pad_token = tokenizer.eos_token
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+     print("Tokenizer and LLM loaded", flush=True)
+
+     dataset = RealTextDataset(tokenizer=tokenizer, max_samples=200, seq_len=64)
+     dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
+     print("Data loader and dataset loaded", flush=True)
+
+     criterion = nn.CrossEntropyLoss()
+
+     optimizer_wrapper = SelectiveFineTuningOptimizer(
+         model=model,
+         base_optimizer_cls=optim.AdamW,
+         optimizer_args={'lr': 5e-5},
+         condition_file='condition_numbers.json',
+         num_tensors_to_finetune=20,
+         recompute=True
+     )
+     print("Optimizer instantiated", flush=True)
+
+     model.train()
+     for epoch in range(3):
+         total_loss = 0
+         for inputs, targets in dataloader:
+             optimizer_wrapper.zero_grad()
+             outputs = model(inputs, labels=targets)
+             loss = outputs.loss
+             loss.backward()
+             optimizer_wrapper.step()
+             total_loss += loss.item()
+         print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
+
+ if __name__ == '__main__':
+     train()
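For reference, the selection criterion the example relies on (implemented in selective_fine_tuning.py below) ranks every 2-D weight tensor by its condition number, the ratio of its largest to its smallest singular value, and unfreezes only the best-conditioned tensors. A minimal sketch of that computation on a single stand-in tensor (random here, not taken from TinyLlama; the wrapper itself calls torch.linalg.svd, while torch.linalg.svdvals is used below for brevity and yields the same singular values):

import torch

weight = torch.randn(256, 128)      # stand-in for one 2-D model parameter
s = torch.linalg.svdvals(weight)    # singular values, returned in descending order
cond = (s[0] / s[-1]).item() if s[-1] > 0 else float('inf')
print(f"condition number: {cond:.2f}")  # lower values are the ones selected for fine-tuning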
selective_fine_tuning.py ADDED
@@ -0,0 +1,251 @@
+ # The class in this script implements a selective fine-tuning method based on the condition number
+ # Author: Oswaldo Ludwig (now with AI support)
+ # Date: 03/07/2025
+ # In case of publication using this script or ideas in this script, cite:
+ # Ludwig, Oswaldo. "The Condition Number as a Scale-Invariant Proxy for Information Encoding in Neural Units." arXiv preprint arXiv:2506.16289 (2025).
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import os
+ import json
+ import numpy as np
+ import logging
+ from typing import Type, Dict, Any, Set, Optional
+
+ # Configure logging (ensure this is at the top level or configured once)
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ class SelectiveFineTuningOptimizer:
+     """
+     A custom optimizer wrapper that selectively fine-tunes a PyTorch model
+     based on the condition numbers of its parameters. Parameters with lower
+     condition numbers are prioritized for fine-tuning.
+     """
+     def __init__(self, model: nn.Module, base_optimizer_cls: Type[optim.Optimizer], optimizer_args: Dict[str, Any],
+                  condition_file: str = 'condition_numbers.json',
+                  num_tensors_to_finetune: int = 100,
+                  recompute: bool = False,
+                  max_dim_size_to_analyze: Optional[int] = None):
+         """
+         Initializes the SelectiveFineTuningOptimizer.
+
+         Args:
+             model (nn.Module): The PyTorch model to be fine-tuned.
+             base_optimizer_cls (Type[optim.Optimizer]): The class of the base optimizer (e.g., torch.optim.Adam).
+             optimizer_args (Dict[str, Any]): A dictionary of arguments to pass to the base optimizer constructor.
+             condition_file (str): Path to the JSON file for storing/loading condition numbers.
+             num_tensors_to_finetune (int): The number of top tensors (based on condition number) to fine-tune.
+             recompute (bool): If True, recompute condition numbers even if the file exists.
+             max_dim_size_to_analyze (int, optional): If provided, any parameter tensor with at least one dimension
+                                                      larger than this value is excluded from the analysis.
+                                                      Useful for ignoring very large embedding matrices.
+         """
+         self.model = model
+         self.condition_file = condition_file
+         self.num_tensors_to_finetune = num_tensors_to_finetune
+         self.recompute = recompute
+         self.max_dim_size_to_analyze = max_dim_size_to_analyze
+
+         self.condition_numbers: Dict[str, float] = {}
+
+         if not os.path.exists(condition_file) or recompute:
+             self.condition_numbers = self._analyze_model()
+             self._save_condition_numbers()
+         else:
+             self.condition_numbers = self._load_condition_numbers()
+
+         self.trainable_param_names: Set[str] = self._select_trainable_parameters()
+         self._unfreeze_selected_parameters()
+
+         # Initialize the base optimizer with the selected parameters
+         params_to_optimize = [p for n, p in model.named_parameters() if n in self.trainable_param_names]
+         if not params_to_optimize:
+             logger.warning("No parameters selected for fine-tuning based on the criteria. Optimizer will have no parameters.")
+         self.optimizer = base_optimizer_cls(params_to_optimize, **optimizer_args)
+         logger.info(f"Optimizer initialized with {len(params_to_optimize)} trainable parameters.")
+
+
+     def _analyze_model(self) -> Dict[str, float]:
+         """
+         Analyzes the singular values of model parameters to compute their condition numbers.
+         Parameters with fewer than 2 dimensions or with any dimension
+         larger than `max_dim_size_to_analyze` are ignored.
+         SVD is performed on the GPU if the tensor is on CUDA, otherwise on CPU.
+
+         Returns:
+             Dict[str, float]: A dictionary mapping parameter names to their condition numbers.
+         """
+         condition_numbers = {}
+         logger.info("Analyzing the model tensors...")
+
+         initial_requires_grad_state = {}
+         for name, param in self.model.named_parameters():
+             initial_requires_grad_state[name] = param.requires_grad
+             param.requires_grad = False  # Temporarily disable for analysis
+
+         analyzed_count = 0
+         skipped_ndim_count = 0
+         skipped_dim_size_count = 0
+         skipped_svd_error_count = 0
+         total_params_in_model = 0
+
+         try:
+             for name, param in self.model.named_parameters():
+                 total_params_in_model += 1
+                 # Filter 1: Skip by number of dimensions
+                 if param.ndim < 2:
+                     logger.debug(f"Skipping {name} due to fewer than 2 dimensions (ndim={param.ndim}).")
+                     skipped_ndim_count += 1
+                     continue
+                 # Filter 2: Skip by any dimension size exceeding the threshold
+                 if self.max_dim_size_to_analyze is not None:
+                     if any(dim_size > self.max_dim_size_to_analyze for dim_size in param.shape):
+                         logger.debug(f"Skipping {name} due to a dimension larger than {self.max_dim_size_to_analyze} (shape={param.shape}).")
+                         skipped_dim_size_count += 1
+                         continue
+
+                 try:
+                     data = param.detach()  # Keep on GPU if already there
+                     if data.is_cuda:
+                         # Perform SVD on GPU
+                         _, s, _ = torch.linalg.svd(data, full_matrices=False)
+                     else:
+                         # Fallback to CPU if not on CUDA
+                         _, s, _ = torch.linalg.svd(data.cpu(), full_matrices=False)
+
+                     cond_number = (s[0] / s[-1]).item() if s[-1] > 0 else float('inf')
+                     condition_numbers[name] = cond_number
+                     analyzed_count += 1
+                     logger.debug(f"Analyzed {name}: condition_number={cond_number:.4f}")
+                 except torch.linalg.LinAlgError as e:
+                     logger.warning(f"Skipping {name} due to SVD Linear Algebra error: {e}")
+                     skipped_svd_error_count += 1
+                 except Exception as e:
+                     logger.error(f"Skipping {name} due to unexpected error during SVD: {e}")
+                     skipped_svd_error_count += 1
+         finally:
+             # Restore the initial requires_grad state (later overridden by _unfreeze_selected_parameters)
+             for name, param in self.model.named_parameters():
+                 param.requires_grad = initial_requires_grad_state[name]
+
+
+         logger.info(f"Done analyzing model tensors. Total parameters in model: {total_params_in_model}")
+         logger.info(f"Parameters analyzed for condition numbers: {analyzed_count}")
+         logger.info(f"Skipped due to ndim < 2: {skipped_ndim_count}")
+         logger.info(f"Skipped due to dimension size > {self.max_dim_size_to_analyze}: {skipped_dim_size_count}")
+         logger.info(f"Skipped due to SVD errors: {skipped_svd_error_count}")
+         return condition_numbers
+
+     def _save_condition_numbers(self):
+         """
+         Saves the computed condition numbers to a JSON file.
+         """
+         try:
+             with open(self.condition_file, 'w') as f:
+                 json.dump(self.condition_numbers, f, indent=2)
+             logger.info(f"Condition numbers saved to {self.condition_file}")
+         except IOError as e:
+             logger.error(f"Failed to save condition numbers to {self.condition_file}: {e}")
+
+     def _load_condition_numbers(self) -> Dict[str, float]:
+         """
+         Loads condition numbers from a JSON file. If the file is corrupted,
+         it triggers a recomputation.
+
+         Returns:
+             Dict[str, float]: The loaded condition numbers.
+         """
+         try:
+             with open(self.condition_file, 'r') as f:
+                 data = json.load(f)
+             logger.info(f"Condition numbers loaded from {self.condition_file}")
+             return data
+         except json.JSONDecodeError as e:
+             logger.warning(f"Condition file '{self.condition_file}' is corrupted or invalid. Error: {e}. Recomputing.")
+             if os.path.exists(self.condition_file):
+                 try:
+                     os.remove(self.condition_file)  # Remove the corrupted file
+                     logger.info(f"Removed corrupted condition file: {self.condition_file}")
+                 except OSError as err:
+                     logger.error(f"Error removing corrupted file {self.condition_file}: {err}")
+             return self._analyze_model()  # Recompute if loading fails
+         except IOError as e:
+             logger.error(f"Failed to load condition numbers from {self.condition_file}: {e}. Recomputing.")
+             return self._analyze_model()  # Recompute if the file is missing or another IO error occurs
+
+     def _select_trainable_parameters(self) -> Set[str]:
+         """
+         Selects the top `num_tensors_to_finetune` parameters based on their condition numbers
+         (lower condition number is better).
+
+         Returns:
+             Set[str]: A set of names of the parameters chosen for fine-tuning.
+         """
+         if not self.condition_numbers:
+             logger.warning("No condition numbers available to select trainable parameters.")
+             return set()
+
+         sorted_params = sorted(self.condition_numbers.items(), key=lambda x: x[1])
+         selected = [name for name, _ in sorted_params[:self.num_tensors_to_finetune]]
+         logger.info(f"Selected {len(selected)} parameters for fine-tuning out of {len(self.condition_numbers)} analyzed.")
+         logger.debug(f"Selected parameters: {selected}")
+         return set(selected)
+
+     def _unfreeze_selected_parameters(self):
+         """
+         Sets `requires_grad=True` for the selected trainable parameters
+         and `requires_grad=False` for all other parameters in the model.
+         """
+         total_params = 0
+         frozen_params_count = 0
+         unfrozen_params_count = 0
+
+         for name, param in self.model.named_parameters():
+             total_params += 1
+             if name in self.trainable_param_names:
+                 if not param.requires_grad:  # Only flip if the current state differs
+                     param.requires_grad = True
+                 unfrozen_params_count += 1
+                 logger.debug(f"Parameter '{name}' set to requires_grad=True.")
+             else:
+                 if param.requires_grad:  # Only flip if the current state differs
+                     param.requires_grad = False
+                 frozen_params_count += 1
+                 logger.debug(f"Parameter '{name}' set to requires_grad=False.")
+
+         logger.info(f"Model parameters configured: {unfrozen_params_count} unfrozen, {frozen_params_count} frozen (out of {total_params} total).")
+
+
+     def step(self):
+         """
+         Performs a single optimization step (parameter update).
+         Delegates to the base optimizer's step method.
+         """
+         self.optimizer.step()
+
+     def zero_grad(self):
+         """
+         Clears the gradients of all optimized parameters.
+         Delegates to the base optimizer's zero_grad method.
+         """
+         self.optimizer.zero_grad()
+
+     def state_dict(self) -> Dict[str, Any]:
+         """
+         Returns a serializable dictionary containing the current state of the optimizer.
+         Delegates to the base optimizer's state_dict method.
+         """
+         return self.optimizer.state_dict()
+
+     def load_state_dict(self, state_dict: Dict[str, Any]):
+         """
+         Loads the optimizer's state from a state_dict.
+         Delegates to the base optimizer's load_state_dict method.
+
+         Args:
+             state_dict (Dict[str, Any]): A dictionary containing the optimizer's state.
+         """
+         self.optimizer.load_state_dict(state_dict)
+
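One way to sanity-check what the wrapper ends up training (a sketch only: it assumes the example script above has already run in the current working directory, so the condition_numbers.json cache written by _save_condition_numbers exists): load the JSON file and list the lowest-condition-number tensors, which are exactly the ones _select_trainable_parameters unfreezes.

import json

# Load the cached analysis written by SelectiveFineTuningOptimizer.
with open("condition_numbers.json") as f:
    cond_numbers = json.load(f)  # maps parameter name -> condition number

# The example uses num_tensors_to_finetune=20, so these are the tensors it trains.
for name, value in sorted(cond_numbers.items(), key=lambda kv: kv[1])[:20]:
    print(f"{value:12.2f}  {name}")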