Yi30 committed on
Commit
6ef035d
·
verified ·
1 Parent(s): b364156

Create convert_for_g2_draft.py

Browse files
Files changed (1) hide show
  1. convert_for_g2_draft.py +78 -0
convert_for_g2_draft.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from safetensors import safe_open
3
+ from safetensors.torch import save_file
4
+ import torch
5
+ from typing import Dict, Set
6
+
7
+
8
def get_tensors(file_path: str) -> Dict[str, torch.Tensor]:
    """Load every tensor from a .safetensors file onto the CPU.

    Args:
        file_path: Path to a .safetensors checkpoint shard.

    Returns:
        Mapping of tensor name -> tensor, all materialized on CPU.
    """
    # safe_open lazily memory-maps the shard; get_tensor pulls each entry.
    with safe_open(file_path, framework="pt", device="cpu") as handle:
        return {name: handle.get_tensor(name) for name in handle.keys()}
14
+
15
+
16
def get_quantized_modules(tensor_keys, keyword: str = "scale") -> Set[str]:
    """Collect the module prefixes that own quantization metadata.

    A key such as ``model.layers.0.mlp.down_proj.weight_scale`` contributes
    the prefix ``model.layers.0.mlp.down_proj`` (everything before the last
    dotted component) whenever *keyword* occurs anywhere in the key.

    Args:
        tensor_keys: Iterable of tensor names from a checkpoint.
        keyword: Substring that marks a key as quantization metadata.

    Returns:
        Set of dotted module prefixes with at least one matching key.
    """
    # Keys without a dot have no module prefix to strip, so skip them.
    return {
        key.rsplit(".", 1)[0]
        for key in tensor_keys
        if keyword in key and "." in key
    }
29
+
30
+
31
+ def modify_quantized_tensors(tensors: Dict[str, torch.Tensor], quantized_modules: Set[str]) -> Dict[str, torch.Tensor]:
32
+ """
33
+ For each quantized module:
34
+ - weight --> divide by 2
35
+ - weight_scale --> multiply by 2
36
+ - input_scale --> multiply by 2
37
+ """
38
+ modified = {}
39
+ for key, tensor in tensors.items():
40
+ modified_tensor = tensor
41
+ for mod in quantized_modules:
42
+ if key == f"{mod}.weight":
43
+ modified_tensor = (tensor.to(torch.float32) / 2).to(torch.float8_e4m3fn)
44
+ elif key == f"{mod}.weight_scale":
45
+ modified_tensor = tensor * 2
46
+ elif key == f"{mod}.input_scale":
47
+ modified_tensor = tensor * 2
48
+ modified[key] = modified_tensor
49
+ return modified
50
+
51
+
52
def process_folder(folder_path: str, output_folder: str):
    """Rescale every .safetensors shard in *folder_path* into *output_folder*.

    Each shard is loaded, its quantized modules are detected, the weights and
    scales are rescaled, and the result is written under the same file name
    in *output_folder*. A shard that fails is reported and skipped
    (best-effort processing).
    """
    os.makedirs(output_folder, exist_ok=True)

    shard_names = [f for f in os.listdir(folder_path) if f.endswith(".safetensors")]
    for shard_name in shard_names:
        src_path = os.path.join(folder_path, shard_name)
        print(f"Processing: {src_path}")
        try:
            shard = get_tensors(src_path)
            modules = get_quantized_modules(shard.keys())
            rescaled = modify_quantized_tensors(shard, modules)

            dst_path = os.path.join(output_folder, shard_name)
            save_file(rescaled, dst_path)
            print(f"Saved modified tensors to: {dst_path}")
        except Exception as exc:
            # Deliberate best-effort: report and continue with the next shard.
            print(f"Failed to process {src_path}: {exc}")
72
+
73
+
74
if __name__ == "__main__":
    # Script entry point: paths are hard-coded for the Hunyuan-7B-Instruct
    # FP8 checkpoint; the rescaled copy is written to a sibling directory.
    input_folder = "/mnt/disk5/tencent/Hunyuan-7B-Instruct-FP8"
    output_folder = "/mnt/disk5/tencent/Hunyuan-7B-Instruct-FP8-modified"

    process_folder(input_folder, output_folder)