rahul7star committed on
Commit
08a95f3
·
verified ·
1 Parent(s): b06df19

Update app_quant.py

Browse files
Files changed (1) hide show
  1. app_quant.py +167 -161
app_quant.py CHANGED
@@ -1,186 +1,192 @@
1
- # ---------------------------------------------------------
2
- # Nava Ultra-Fast CPU Inference (4-bit Quant + Caching)
3
- # ---------------------------------------------------------
4
  import gradio as gr
5
  import torch
6
  import soundfile as sf
7
  from pathlib import Path
 
 
8
 
9
- from transformers import (
10
- AutoTokenizer,
11
- AutoModelForCausalLM,
12
- BitsAndBytesConfig
13
- )
14
- from peft import PeftModel
15
  from snac import SNAC
16
 
17
- # ---------------------------------------------------------
18
- # CONFIG
19
- # ---------------------------------------------------------
20
  MODEL_NAME = "rahul7star/nava1.0"
21
  LORA_NAME = "rahul7star/nava-audio"
22
  SNAC_MODEL_NAME = "rahul7star/nava-snac"
23
 
24
-
25
-
26
  TARGET_SR = 24000
27
- DEFAULT_BATCH_SIZE = 500
28
- MICRO_BATCH = 2
29
- SEQ_LEN = 2048
30
  OUT_ROOT = Path("/tmp/data")
31
  OUT_ROOT.mkdir(exist_ok=True, parents=True)
32
 
33
- DEFAULT_TEXT = (
34
- "राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से "
35
- "निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी"
36
- )
37
-
38
- DEVICE = "cpu"
39
-
40
- # ---------------------------------------------------------
41
- # QUANT CONFIG (4-BIT)
42
- # ---------------------------------------------------------
43
- quant_config = BitsAndBytesConfig(
44
- load_in_4bit=True,
45
- bnb_4bit_quant_type="nf4",
46
- bnb_4bit_use_double_quant=True,
47
- bnb_4bit_compute_dtype=torch.bfloat16,
48
- )
49
-
50
- # ---------------------------------------------------------
51
- # LOAD TOKENIZER (cached)
52
- # ---------------------------------------------------------
53
- print("🔄 Loading tokenizer...")
54
- tokenizer = AutoTokenizer.from_pretrained(
55
- MODEL_NAME,
56
- trust_remote_code=True
57
- )
58
-
59
- # ---------------------------------------------------------
60
- # LOAD BASE MODEL (4-bit CPU quant)
61
- # ---------------------------------------------------------
62
- print("🔄 Loading base model in 4-bit…")
63
- base_model = AutoModelForCausalLM.from_pretrained(
64
- MODEL_NAME,
65
- quantization_config=quant_config,
66
- device_map={"": DEVICE},
67
- torch_dtype=torch.bfloat16,
68
- trust_remote_code=True
69
- )
70
-
71
- # ---------------------------------------------------------
72
- # LOAD LORA (merged on top)
73
- # ---------------------------------------------------------
74
- print("🔄 Loading LoRA weights…")
75
- model = PeftModel.from_pretrained(
76
- base_model,
77
- LORA_NAME,
78
- device_map={"": DEVICE}
79
- ).eval()
80
-
81
- # ---------------------------------------------------------
82
- # LOAD SNAC ONCE ONLY
83
- # ---------------------------------------------------------
84
- print("🔄 Loading SNAC…")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)
 
86
 
87
-
88
- # =========================================================
89
- # INFERENCE FUNCTION
90
- # =========================================================
91
- def generate_audio_cpu_lora(text):
92
-
93
  logs = []
94
- logs.append("⚡ Running fast 4-bit CPU inference…")
95
-
96
- # Tokens
97
- soh = tokenizer.decode([128259])
98
- eoh = tokenizer.decode([128260])
99
- soa = tokenizer.decode([128261])
100
- sos = tokenizer.decode([128257])
101
- eot = tokenizer.decode([128009])
102
- bos = tokenizer.bos_token
103
-
104
- prompt = soh + bos + text + eot + eoh + soa + sos
105
- inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
106
-
107
- # -----------------------------------------------------
108
- # GENERATE SNAC TOKENS (FAST 4-bit)
109
- # -----------------------------------------------------
110
- with torch.inference_mode():
111
- outputs = model.generate(
112
- **inputs,
113
- max_new_tokens=SEQ_LEN,
114
- temperature=0.4,
115
- top_p=0.9,
116
- repetition_penalty=1.1,
117
- do_sample=True,
118
- eos_token_id=128258,
119
- pad_token_id=tokenizer.pad_token_id
120
- )
121
-
122
- # Strip prompt
123
- gen_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()
124
-
125
- # Extract valid SNAC tokens
126
- snac_min, snac_max = 128266, 156937
127
- eos_id = 128258
128
- eos_idx = gen_ids.index(eos_id) if eos_id in gen_ids else len(gen_ids)
129
-
130
- snac_tokens = [t for t in gen_ids[:eos_idx] if snac_min <= t <= snac_max]
131
-
132
- # -----------------------------------------------------
133
- # DECODE SNAC → AUDIO
134
- # -----------------------------------------------------
135
- l1, l2, l3 = [], [], []
136
- frames = len(snac_tokens) // 7
137
- snac_tokens = snac_tokens[:frames * 7]
138
-
139
- for i in range(frames):
140
- s = snac_tokens[i * 7:(i + 1) * 7]
141
- l1.append((s[0] - snac_min) % 4096)
142
- l2.extend([(s[1]-snac_min)%4096, (s[4]-snac_min)%4096])
143
- l3.extend([(s[2]-snac_min)%4096, (s[3]-snac_min)%4096,
144
- (s[5]-snac_min)%4096, (s[6]-snac_min)%4096])
145
-
146
- codes = [
147
- torch.tensor(l1).unsqueeze(0),
148
- torch.tensor(l2).unsqueeze(0),
149
- torch.tensor(l3).unsqueeze(0)
150
- ]
151
-
152
- with torch.inference_mode():
153
- z = snac_model.quantizer.from_codes(codes)
154
- audio = snac_model.decoder(z)[0, 0].cpu().numpy()
155
-
156
- # Remove crackles
157
- if len(audio) > 2048:
158
- audio = audio[2048:]
159
-
160
- # Save WAV
161
- out = OUT_ROOT / "tts_output_cpu_lora.wav"
162
- sf.write(out, audio, TARGET_SR)
163
-
164
- logs.append("🎧 Audio generated successfully")
165
-
166
- return str(out), str(out), "\n".join(logs)
167
-
168
-
169
- # =========================================================
170
- # GRADIO UI
171
- # =========================================================
172
  with gr.Blocks() as demo:
173
- gr.Markdown("## Maya TTS — Ultra-Fast 4-bit CPU Inference")
174
-
175
- txt = gr.Textbox(label="Enter text", value=DEFAULT_TEXT)
176
  btn = gr.Button("Generate Audio")
177
-
178
  audio = gr.Audio(label="Audio", type="filepath")
179
  file = gr.File(label="Download")
180
- logs = gr.Textbox(label="Logs")
181
-
182
- btn.click(generate_audio_cpu_lora, [txt], [audio, file, logs])
183
-
184
 
185
  if __name__ == "__main__":
186
  demo.launch()
 
1
+ # app_quant.py
 
 
2
  import gradio as gr
3
  import torch
4
  import soundfile as sf
5
  from pathlib import Path
6
+ import traceback
7
+ import time
8
 
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from peft import PeftModel, LoraConfig
 
 
 
 
11
  from snac import SNAC
12
 
13
+ # -------------------------
 
 
14
  MODEL_NAME = "rahul7star/nava1.0"
15
  LORA_NAME = "rahul7star/nava-audio"
16
  SNAC_MODEL_NAME = "rahul7star/nava-snac"
17
 
 
 
18
  TARGET_SR = 24000
 
 
 
19
  OUT_ROOT = Path("/tmp/data")
20
  OUT_ROOT.mkdir(exist_ok=True, parents=True)
21
 
22
+ DEFAULT_TEXT = "राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी"
23
+
24
+ # conservative defaults
25
+ SEQ_LEN_GPU = 240000 # if you really have GPU
26
+ SEQ_LEN_CPU = 4096 # keep CPU small to avoid OOM
27
+ MAX_NEW_TOKENS_CPU = 1024
28
+ MAX_NEW_TOKENS_GPU = 240000
29
+
30
+ # detect device
31
+ HAS_CUDA = torch.cuda.is_available()
32
+ DEVICE = "cuda" if HAS_CUDA else "cpu"
33
+
34
+ # optional: import BitsAndBytesConfig (4-bit quantization config) only if CUDA is available
35
+ try:
36
+ if HAS_CUDA:
37
+ from transformers import BitsAndBytesConfig
38
+ bnb_available = True
39
+ else:
40
+ bnb_available = False
41
+ except Exception:
42
+ bnb_available = False
43
+
44
+ print(f"[init] CUDA available: {HAS_CUDA}, bitsandbytes available: {bnb_available}")
45
+
46
+ # -------------------------
47
+ # Load tokenizer (always)
48
+ # -------------------------
49
+ print("[init] Loading tokenizer...")
50
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
51
+
52
+ # -------------------------
53
+ # Load base model & LoRA (GPU vs CPU safe)
54
+ # -------------------------
55
+ print("[init] Loading base model + LoRA (this may take a while)...")
56
+ if HAS_CUDA and bnb_available:
57
+ # GPU + bnb path: use 4-bit quant
58
+ quant_config = BitsAndBytesConfig(
59
+ load_in_4bit=True,
60
+ bnb_4bit_quant_type="nf4",
61
+ bnb_4bit_use_double_quant=True,
62
+ bnb_4bit_compute_dtype=torch.bfloat16
63
+ )
64
+ base_model = AutoModelForCausalLM.from_pretrained(
65
+ MODEL_NAME,
66
+ quantization_config=quant_config,
67
+ device_map="auto",
68
+ trust_remote_code=True,
69
+ )
70
+ model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map="auto")
71
+ SEQ_LEN = SEQ_LEN_GPU
72
+ MAX_NEW_TOKENS = MAX_NEW_TOKENS_GPU
73
+ print("[init] Loaded model in 4-bit (GPU).")
74
+ else:
75
+ # CPU fallback: load in FP32 with low_cpu_mem_usage
76
+ # Avoid load_in_4bit on CPU
77
+ base_model = AutoModelForCausalLM.from_pretrained(
78
+ MODEL_NAME,
79
+ torch_dtype=torch.float32,
80
+ device_map={"": "cpu"},
81
+ low_cpu_mem_usage=True,
82
+ trust_remote_code=True,
83
+ )
84
+ # attach PEFT adapter - this will add LoRA wrappers but keep base weights on CPU
85
+ model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map={"": "cpu"})
86
+ SEQ_LEN = SEQ_LEN_CPU
87
+ MAX_NEW_TOKENS = MAX_NEW_TOKENS_CPU
88
+ print("[init] Loaded model on CPU (FP32) with LoRA.")
89
+
90
+ model.eval()
91
+
92
+ # -------------------------
93
+ # Load SNAC (once)
94
+ # -------------------------
95
+ print("[init] Loading SNAC...")
96
  snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)
97
+ print("[init] SNAC loaded.")
98
 
99
+ # -------------------------
100
+ # Inference function
101
+ # -------------------------
102
+ def generate_audio_cpu_lora(text: str):
 
 
103
  logs = []
104
+ t0 = time.time()
105
+ try:
106
+ logs.append(f"[INFO] Device: {DEVICE} | SEQ_LEN: {SEQ_LEN} | MAX_NEW_TOKENS: {MAX_NEW_TOKENS}")
107
+
108
+ # Build prompt (same as your earlier code)
109
+ soh = tokenizer.decode([128259]); eoh = tokenizer.decode([128260])
110
+ soa = tokenizer.decode([128261]); sos = tokenizer.decode([128257])
111
+ eot = tokenizer.decode([128009])
112
+ bos = tokenizer.bos_token
113
+ prompt = soh + bos + text + eot + eoh + soa + sos
114
+
115
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
116
+
117
+ # Keep generated tokens small on CPU
118
+ max_new = min(MAX_NEW_TOKENS, 1024) if DEVICE == "cpu" else MAX_NEW_TOKENS
119
+
120
+ with torch.inference_mode():
121
+ outputs = model.generate(
122
+ **inputs,
123
+ max_new_tokens=max_new,
124
+ temperature=0.4,
125
+ top_p=0.9,
126
+ repetition_penalty=1.1,
127
+ do_sample=True,
128
+ eos_token_id=128258,
129
+ pad_token_id=tokenizer.pad_token_id
130
+ )
131
+
132
+ # extract generated part
133
+ gen_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()
134
+ logs.append(f"[INFO] Generated {len(gen_ids)} tokens")
135
+
136
+ # filter SNAC tokens (same logic)
137
+ snac_min, snac_max = 128266, 156937
138
+ eos_id = 128258
139
+ eos_idx = gen_ids.index(eos_id) if eos_id in gen_ids else len(gen_ids)
140
+ snac_tokens = [t for t in gen_ids[:eos_idx] if snac_min <= t <= snac_max]
141
+
142
+ frames = len(snac_tokens) // 7
143
+ snac_tokens = snac_tokens[:frames*7]
144
+
145
+ l1, l2, l3 = [], [], []
146
+ for i in range(frames):
147
+ s = snac_tokens[i*7:(i+1)*7]
148
+ l1.append((s[0]-snac_min) % 4096)
149
+ l2.extend([(s[1]-snac_min)%4096, (s[4]-snac_min)%4096])
150
+ l3.extend([(s[2]-snac_min)%4096, (s[3]-snac_min)%4096, (s[5]-snac_min)%4096, (s[6]-snac_min)%4096])
151
+
152
+ if len(l1) == 0:
153
+ logs.append("[WARN] No SNAC frames found in generated tokens. Returning debug logs.")
154
+ return None, None, "\n".join(logs)
155
+
156
+ codes_tensor = [torch.tensor(l1, dtype=torch.long, device=DEVICE).unsqueeze(0),
157
+ torch.tensor(l2, dtype=torch.long, device=DEVICE).unsqueeze(0),
158
+ torch.tensor(l3, dtype=torch.long, device=DEVICE).unsqueeze(0)]
159
+
160
+ with torch.inference_mode():
161
+ z_q = snac_model.quantizer.from_codes(codes_tensor)
162
+ audio = snac_model.decoder(z_q)[0,0].cpu().numpy()
163
+
164
+ if len(audio) > 2048:
165
+ audio = audio[2048:]
166
+
167
+ out_path = OUT_ROOT / f"tts_output_cpu_lora.wav"
168
+ sf.write(out_path, audio, TARGET_SR)
169
+ logs.append(f"[OK] Audio saved: {out_path} (duration {len(audio)/TARGET_SR:.2f}s)")
170
+
171
+ logs.append(f"[TIME] Elapsed {time.time()-t0:.2f}s")
172
+ return str(out_path), str(out_path), "\n".join(logs)
173
+
174
+ except Exception as e:
175
+ tb = traceback.format_exc()
176
+ logs.append(f"[ERROR] {e}\n{tb}")
177
+ return None, None, "\n".join(logs)
178
+
179
+ # -------------------------
180
+ # Gradio UI
181
+ # -------------------------
182
  with gr.Blocks() as demo:
183
+ gr.Markdown("## Maya TTS — CPU/GPU safe")
184
+ txt = gr.Textbox(label="Enter text", value=DEFAULT_TEXT, lines=2)
 
185
  btn = gr.Button("Generate Audio")
 
186
  audio = gr.Audio(label="Audio", type="filepath")
187
  file = gr.File(label="Download")
188
+ logs_box = gr.Textbox(label="Logs", lines=10)
189
+ btn.click(generate_audio_cpu_lora, [txt], [audio, file, logs_box])
 
 
190
 
191
  if __name__ == "__main__":
192
  demo.launch()