rahul7star committed on
Commit 9f11255 · verified · 1 Parent(s): 55a1499

Update app_quant.py

Files changed (1):
  1. app_quant.py +105 -84
app_quant.py CHANGED
@@ -1,4 +1,10 @@
- # app.py
+ # app_optimized.py
+ """
+ Optimized inference for Maya1 + LoRA + SNAC.
+ Keeps your UI unchanged; replaces internal model loading + generate paths
+ to run much faster (preload everything, SNAC on GPU when available, reuse tokens).
+ """
+
  import gradio as gr
  import torch
  import soundfile as sf
@@ -12,19 +18,18 @@ from peft import PeftModel
  from snac import SNAC

  # -------------------------
- # Config / constants
+ # Config / constants (same as you)
  # -------------------------
- MODEL_NAME = "rahul7star/nava1.0" # base maya model (your variant)
- LORA_NAME = "rahul7star/nava-audio" # your LoRA adapter
- SNAC_MODEL_NAME = "rahul7star/nava-snac" # snac decoder (use hub model id)
+ MODEL_NAME = "rahul7star/nava1.0"
+ LORA_NAME = "rahul7star/nava-audio"
+ SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz" # decoder
  TARGET_SR = 24000
  OUT_ROOT = Path("/tmp/data")
  OUT_ROOT.mkdir(exist_ok=True, parents=True)

  DEFAULT_TEXT = "राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी"
- EXAMPLE_AUDIO_PATH = "audio.wav" # file in repo root, user-supplied
+ EXAMPLE_AUDIO_PATH = "audio.wav"

- # Preset characters (2 realistic + 2 creative + Custom)
  PRESET_CHARACTERS = {
      "Male American": {
          "description": "Realistic male voice in the 20s age with an american accent. High pitch, raspy timbre, brisk pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery",
@@ -43,29 +48,28 @@ PRESET_CHARACTERS = {
          "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is a viable course of action. <chuckle> Why would we ever consider running away very fast."
      },
      "Custom": {
-         "description": "", # user will edit
+         "description": "",
          "example_text": DEFAULT_TEXT
      }
  }

- # Emotion tags (full list you asked to support)
  EMOTION_TAGS = [
      "<neutral>", "<angry>", "<chuckle>", "<cry>", "<disappointed>",
      "<excited>", "<gasp>", "<giggle>", "<laugh>", "<laugh_harder>",
      "<sarcastic>", "<sigh>", "<sing>", "<whisper>"
  ]

- # Short safety / generation limits
+ # length limits
  SEQ_LEN_CPU = 4096
  MAX_NEW_TOKENS_CPU = 1024
  SEQ_LEN_GPU = 240000
  MAX_NEW_TOKENS_GPU = 240000

- # Detect devices
+ # detect device
  HAS_CUDA = torch.cuda.is_available()
  DEVICE = "cuda" if HAS_CUDA else "cpu"

- # Try to detect bitsandbytes availability for faster GPU inference (4-bit)
+ # try bitsandbytes for faster GPU (optional)
  bnb_available = False
  if HAS_CUDA:
      try:
@@ -77,14 +81,30 @@ if HAS_CUDA:
  print(f"[init] cuda={HAS_CUDA}, bnb={bnb_available}, device={DEVICE}")

  # -------------------------
- # Load tokenizer + model + LoRA + SNAC ONCE (startup)
+ # Load tokenizer and model + LoRA once at startup (optimized)
  # -------------------------
  print("[init] loading tokenizer...")
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

- print("[init] loading base model + LoRA adapter (this can take time)...")
+ # Precompute commonly used special tokens (avoid repeated decode calls)
+ SOH = tokenizer.decode([128259])
+ EOH = tokenizer.decode([128260])
+ SOA = tokenizer.decode([128261])
+ SOS = tokenizer.decode([128257])
+ EOT = tokenizer.decode([128009])
+ BOS = tokenizer.bos_token
+
+ # Optionally compile model later if torch>=2 and CPU path (safe-guarded)
+ enable_torch_compile = False
+ try:
+     if not HAS_CUDA and hasattr(torch, "compile"):
+         enable_torch_compile = True
+ except Exception:
+     enable_torch_compile = False
+
+ print("[init] loading base model + LoRA (this may take time)...")
  if HAS_CUDA and bnb_available:
-     # GPU + bnb path (fastest inference if available)
+     # GPU + bnb path (fastest if available)
      quant_config = BitsAndBytesConfig(
          load_in_4bit=True,
          bnb_4bit_quant_type="nf4",
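The quantization config is truncated here by the hunk boundary. A typical completion, offered only as an assumption about the remaining fields (double quantization and a bf16 compute dtype are the usual companions of nf4 loading):

```python
import torch
from transformers import BitsAndBytesConfig

# Assumed shape of the full config; the diff shows only the first two kwargs.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,          # assumption
    bnb_4bit_compute_dtype=torch.bfloat16,   # assumption
)
```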
@@ -100,7 +120,7 @@ if HAS_CUDA and bnb_available:
      model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map="auto")
      SEQ_LEN = SEQ_LEN_GPU
      MAX_NEW_TOKENS = MAX_NEW_TOKENS_GPU
-     print("[init] loaded base+LoRA on GPU (4-bit via bnb).")
+     print("[init] loaded base+LoRA on GPU (4-bit).")
  else:
      # CPU fallback - load base into CPU memory and attach LoRA
      base_model = AutoModelForCausalLM.from_pretrained(
@@ -115,49 +135,58 @@ else:
      MAX_NEW_TOKENS = MAX_NEW_TOKENS_CPU
      print("[init] loaded base+LoRA on CPU (FP32).")

+ # Ensure cache usage
+ try:
+     model.config.use_cache = True
+ except Exception:
+     pass
+
+ # Optionally compile model for faster CPU (if available and tested)
+ if enable_torch_compile:
+     try:
+         print("[init] compiling model (torch.compile)...")
+         model = torch.compile(model)
+     except Exception as e:
+         print("[init] torch.compile failed, continuing without it:", e)
+
  model.eval()
  print("[init] model ready.")

- print("[init] loading SNAC decoder...")
- snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)
+ # -------------------------
+ # Load SNAC decoder once (prefer GPU device for decoder)
+ # -------------------------
+ snac_device = DEVICE if HAS_CUDA else "cpu"
+ print(f"[init] loading SNAC decoder onto {snac_device} ...")
+ snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(snac_device)
  print("[init] snac ready.")

- # --------------
- # Helper: build prompt per Maya conventions
- # --------------
+ # Optional: if you have an upsampler like in your FastAudioSR path, plug it here (omitted for portability)
+
+ # -------------------------
+ # Helper: build Maya-style prompt (reusing tokens above)
+ # -------------------------
  def build_maya_prompt(description: str, text: str):
-     # use the special tokens used by maya-style models
-     soh_token = tokenizer.decode([128259]) # SOH
-     eoh_token = tokenizer.decode([128260]) # EOH
-     soa_token = tokenizer.decode([128261]) # SOA
-     sos_token = tokenizer.decode([128257]) # SOS (code start)
-     eot_token = tokenizer.decode([128009]) # TEXT_EOT / EOT marker
-     bos_token = tokenizer.bos_token
-
-     # We use the simple format: "<description> <text>" and Maya wrappers
      formatted = f'<description="{description}"> {text}'
-     prompt = soh_token + bos_token + formatted + eot_token + eoh_token + soa_token + sos_token
-     return prompt
+     # use precomputed tokens for speed
+     return SOH + BOS + formatted + EOT + EOH + SOA + SOS

- # --------------
- # Core generate function (uses preloaded model & snac)
- # --------------
- def generate_from_loaded_model(final_text: str):
-     """
-     final_text: text that already contains description + emotion + user text
-     returns: (audio_path_str, download_path_str, logs_str)
-     """
+ # -------------------------
+ # Optimized generator: reuse tokenizer/model/snac in memory
+ # -------------------------
+ def generate_from_loaded_model(final_prompt: str, max_new_tokens_override: int = None):
      logs = []
      t0 = time.time()
      try:
-         logs.append(f"[info] device={DEVICE} | seq_len={SEQ_LEN}")
-
-         prompt = final_text
-         inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
+         # tokenise WITHOUT adding extra padding if not needed
+         inputs = tokenizer(final_prompt, return_tensors="pt", truncation=True).to(DEVICE)

-         max_new = MAX_NEW_TOKENS if DEVICE == "cuda" else min(MAX_NEW_TOKENS, 1024)
+         # choose new-token budget
+         if max_new_tokens_override is not None:
+             max_new = max_new_tokens_override
+         else:
+             max_new = MAX_NEW_TOKENS if DEVICE == "cuda" else min(MAX_NEW_TOKENS, 1024)

-         # Use inference_mode for speed
+         # Use inference_mode (fast) and use_cache (set earlier)
          with torch.inference_mode():
              outputs = model.generate(
                  **inputs,
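For reference, the rewritten `build_maya_prompt` concatenates the precomputed wrappers around the formatted text. A symbolic trace (the literal strings depend on how the tokenizer renders ids 128259, 128260, 128261, 128257 and 128009):

```python
# Symbolic trace of build_maya_prompt(description, text):
#   SOH + BOS + f'<description="{description}"> {text}' + EOT + EOH + SOA + SOS
# e.g. with description='<excited> warm narrator' and text='Hello':
#   <SOH><BOS><description="<excited> warm narrator"> Hello<EOT><EOH><SOA><SOS>
prompt = build_maya_prompt("<excited> warm narrator", "Hello")
```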
@@ -168,13 +197,13 @@ def generate_from_loaded_model(final_text: str):
                  do_sample=True,
                  eos_token_id=128258,
                  pad_token_id=tokenizer.pad_token_id,
+                 use_cache=True,
              )

-         # Grab generated ids (after prompt length)
          gen_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()
          logs.append(f"[info] generated tokens: {len(gen_ids)}")

-         # Extract SNAC tokens (range used by Maya/SNAC)
+         # Extract SNAC tokens
          SNAC_MIN = 128266
          SNAC_MAX = 156937
          EOS_ID = 128258
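Old line 181 / new line 210 (the `eos_idx` computation used just below) sits between this hunk and the next and is unchanged, so it never appears in the diff. A sketch consistent with how `eos_idx` is used, offered as an assumption:

```python
# Assumed unchanged line between the hunks: stop at the first EOS token,
# or keep the whole sequence if none was emitted.
eos_idx = gen_ids.index(EOS_ID) if EOS_ID in gen_ids else len(gen_ids)
```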
@@ -182,38 +211,39 @@ def generate_from_loaded_model(final_text: str):
          snac_tokens = [t for t in gen_ids[:eos_idx] if SNAC_MIN <= t <= SNAC_MAX]

          frames = len(snac_tokens) // 7
-         snac_tokens = snac_tokens[:frames*7]
+         snac_tokens = snac_tokens[:frames * 7]

-         if frames == 0 or len(snac_tokens) == 0:
-             logs.append("[warn] no SNAC frames found in generated tokens — returning debug logs.")
+         if frames == 0:
+             logs.append("[warn] no SNAC frames found")
              return None, None, "\n".join(logs)

-         # De-interleave into l1, l2, l3
+         # de-interleave
          l1, l2, l3 = [], [], []
          for i in range(frames):
-             s = snac_tokens[i*7:(i+1)*7]
+             s = snac_tokens[i * 7:(i + 1) * 7]
              l1.append((s[0] - SNAC_MIN) % 4096)
              l2.extend([(s[1] - SNAC_MIN) % 4096, (s[4] - SNAC_MIN) % 4096])
              l3.extend([(s[2] - SNAC_MIN) % 4096, (s[3] - SNAC_MIN) % 4096, (s[5] - SNAC_MIN) % 4096, (s[6] - SNAC_MIN) % 4096])

-         # Convert to tensors on decoder device and decode
+         # move codes to decoder device (snac_device)
          codes_tensor = [
-             torch.tensor(l1, dtype=torch.long, device=DEVICE).unsqueeze(0),
-             torch.tensor(l2, dtype=torch.long, device=DEVICE).unsqueeze(0),
-             torch.tensor(l3, dtype=torch.long, device=DEVICE).unsqueeze(0),
+             torch.tensor(l1, dtype=torch.long, device=snac_device).unsqueeze(0),
+             torch.tensor(l2, dtype=torch.long, device=snac_device).unsqueeze(0),
+             torch.tensor(l3, dtype=torch.long, device=snac_device).unsqueeze(0),
          ]

+         # decode to audio on SNAC device
          with torch.inference_mode():
              z_q = snac_model.quantizer.from_codes(codes_tensor)
              audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()

-         # Remove warmup if present and save
+         # remove warmup region
          if len(audio) > 2048:
              audio = audio[2048:]

-         out_path = OUT_ROOT / "tts_output_loaded_lora.wav"
+         out_path = OUT_ROOT / "tts_output_optimized.wav"
          sf.write(out_path, audio, TARGET_SR)
-         logs.append(f"[ok] saved {out_path} duration={(len(audio)/TARGET_SR):.2f}s")
+         logs.append(f"[ok] saved {out_path} duration {len(audio)/TARGET_SR:.2f}s")
          logs.append(f"[time] elapsed {time.time() - t0:.2f}s")

          return str(out_path), str(out_path), "\n".join(logs)
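The de-interleave indices encode SNAC's hierarchical layout: each 7-token frame carries 1 code for the coarsest codebook, 2 for the middle one and 4 for the finest. A self-contained check of that mapping, mirroring the loop above:

```python
# One synthetic 7-token frame, offset from SNAC_MIN, de-interleaved exactly
# as in the loop: position 0 -> l1, positions 1,4 -> l2, positions 2,3,5,6 -> l3.
SNAC_MIN = 128266
frame = [SNAC_MIN + k for k in range(7)]
l1 = [(frame[0] - SNAC_MIN) % 4096]
l2 = [(frame[i] - SNAC_MIN) % 4096 for i in (1, 4)]
l3 = [(frame[i] - SNAC_MIN) % 4096 for i in (2, 3, 5, 6)]
assert (l1, l2, l3) == ([0], [1, 4], [2, 3, 5, 6])
```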
@@ -223,37 +253,29 @@ def generate_from_loaded_model(final_text: str):
          logs.append(f"[error] {e}\n{tb}")
          return None, None, "\n".join(logs)

+
  # --------------
- # UI glue: combine description + emotion + user text (3a)
+ # UI glue (keeps your layout EXACTLY)
  # --------------
  def generate_for_ui(text, preset_name, description, emotion):
-     logs = []
-     try:
-         # If user selected a preset, and description param is empty (e.g. custom not edited),
-         # take preset description
-         if preset_name in PRESET_CHARACTERS and (not description or description.strip() == ""):
-             description = PRESET_CHARACTERS[preset_name]["description"]
-
-         # combine (3a): final_text = f"{emotion} {description}. {text}"
-         # For Maya prompt, we pass the combined description+text to build_maya_prompt
-         combined_desc = f"{emotion} {description}".strip()
-         final_plain = f"{combined_desc}. {text}".strip()
-         final_prompt = build_maya_prompt(combined_desc, text) # keep maya wrapper
-
-         audio_path, download_path, gen_logs = generate_from_loaded_model(final_prompt)
-         if audio_path is None:
-             return None, None, gen_logs
-         return audio_path, download_path, gen_logs
+     # choose preset description if blank
+     if preset_name in PRESET_CHARACTERS and (not description or description.strip() == ""):
+         description = PRESET_CHARACTERS[preset_name]["description"]
+
+     # combine (3a): final_text = f"{emotion} {description}. {text}"
+     combined_desc = f"{emotion} {description}".strip()
+     final_prompt = build_maya_prompt(combined_desc, text)
+
+     # call optimized generator
+     return generate_from_loaded_model(final_prompt)

-     except Exception as e:
-         return None, None, f"[error] {e}\n{traceback.format_exc()}"

  # -------------------------
- # Gradio UI (keeps your layout; wide container)
+ # Gradio UI (unchanged UI layout)
  # -------------------------
  css = ".gradio-container {max-width: 1400px}"
- with gr.Blocks(title="NAVA — VEEN + LoRA + SNAC (Optimized)", css=css) as demo:
-     gr.Markdown("# 🪶 NAVA — VEEN + LoRA + SNAC (Optimized)\nGenerate emotional Hindi speech using Maya1 base + your LoRA adapter.")
+ with gr.Blocks(title="NAVA — Maya1 + LoRA + SNAC (Optimized)", css=css) as demo:
+     gr.Markdown("# 🪶 NAVA — Maya1 + LoRA + SNAC (Optimized)\nGenerate emotional Hindi speech using Maya1 base + your LoRA adapter.")
      with gr.Row():
          with gr.Column(scale=3):
              gr.Markdown("## Inference (CPU/GPU auto)\nType text + pick a preset or write description manually.")
@@ -284,6 +306,5 @@ with gr.Blocks(title="NAVA — VEEN + LoRA + SNAC (Optimized)", css=css) as demo
                  inputs=[text_in, preset_select, description_box, emotion_select],
                  outputs=[audio_player, download_file, gen_logs])

- # -------------------------
  if __name__ == "__main__":
      demo.launch()
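The new `max_new_tokens_override` parameter is never exercised by the UI path above. A quick way to use it for a short CPU test run (values are illustrative):

```python
# Illustrative: cap the token budget so a CPU run finishes quickly.
path, _, logs = generate_from_loaded_model(
    build_maya_prompt("<neutral> calm narrator", DEFAULT_TEXT),
    max_new_tokens_override=256,
)
print(logs)
```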
 