chore: Refactor TTS functionality and dependencies
Browse files
- kitt/core/tts.py +0 -28
- main.py +1 -4
kitt/core/tts.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
| 1 |
import copy
|
| 2 |
from collections import namedtuple
|
| 3 |
|
| 4 |
-
import soundfile as sf
|
| 5 |
import torch
|
| 6 |
from loguru import logger
|
| 7 |
-
from parler_tts import ParlerTTSForConditionalGeneration
|
| 8 |
from replicate import Client
|
| 9 |
-
from transformers import AutoTokenizer
|
| 10 |
|
| 11 |
from kitt.skills.common import config
|
| 12 |
|
|
@@ -94,31 +91,6 @@ def run_tts_replicate(text: str, voice_character: str):
|
|
| 94 |
return output
|
| 95 |
|
| 96 |
|
| 97 |
-
def get_fast_tts():
|
| 98 |
-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 99 |
-
|
| 100 |
-
model = ParlerTTSForConditionalGeneration.from_pretrained(
|
| 101 |
-
"parler-tts/parler-tts-mini-expresso"
|
| 102 |
-
).to(device)
|
| 103 |
-
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
|
| 104 |
-
return model, tokenizer, device
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
fast_tts = get_fast_tts()
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
def run_tts_fast(text: str):
|
| 111 |
-
model, tokenizer, device = fast_tts
|
| 112 |
-
description = "Thomas speaks moderately slowly in a sad tone with emphasis and high quality audio."
|
| 113 |
-
|
| 114 |
-
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
|
| 115 |
-
prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
|
| 116 |
-
|
| 117 |
-
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
|
| 118 |
-
audio_arr = generation.cpu().numpy().squeeze()
|
| 119 |
-
return (model.config.sampling_rate, audio_arr), dict(text=text, voice="Thomas")
|
| 120 |
-
|
| 121 |
-
|
| 122 |
def load_melo_tts():
|
| 123 |
from melo.api import TTS as MeloTTS
|
| 124 |
|
|
|
|
| 1 |
import copy
|
| 2 |
from collections import namedtuple
|
| 3 |
|
|
|
|
| 4 |
import torch
|
| 5 |
from loguru import logger
|
|
|
|
| 6 |
from replicate import Client
|
|
|
|
| 7 |
|
| 8 |
from kitt.skills.common import config
|
| 9 |
|
|
|
|
| 91 |
return output
|
| 92 |
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
def load_melo_tts():
|
| 95 |
from melo.api import TTS as MeloTTS
|
| 96 |
|
main.py
CHANGED
|
@@ -9,7 +9,7 @@ from kitt.core import utils as kitt_utils
|
|
| 9 |
from kitt.core import voice_options
|
| 10 |
from kitt.core.model import generate_function_call as process_query
|
| 11 |
from kitt.core.stt import save_and_transcribe_audio
|
| 12 |
-
from kitt.core.tts import prep_for_tts, run_melo_tts, run_tts_fast, run_tts_replicate
|
| 13 |
from kitt.skills import (
|
| 14 |
code_interpreter,
|
| 15 |
date_time_info,
|
|
@@ -118,9 +118,6 @@ def run_llama3_model(query, voice_character, state):
|
|
| 118 |
voice_out = tts_gradio(
|
| 119 |
output_text_tts, voice_character, speaker_embedding_cache
|
| 120 |
)[0]
|
| 121 |
-
#
|
| 122 |
-
# voice_out = run_tts_fast(output_text)[0]
|
| 123 |
-
#
|
| 124 |
return (
|
| 125 |
output_text,
|
| 126 |
voice_out,
|
|
|
|
| 9 |
from kitt.core import voice_options
|
| 10 |
from kitt.core.model import generate_function_call as process_query
|
| 11 |
from kitt.core.stt import save_and_transcribe_audio
|
| 12 |
+
from kitt.core.tts import prep_for_tts, run_melo_tts, run_tts_replicate
|
| 13 |
from kitt.skills import (
|
| 14 |
code_interpreter,
|
| 15 |
date_time_info,
|
|
|
|
| 118 |
voice_out = tts_gradio(
|
| 119 |
output_text_tts, voice_character, speaker_embedding_cache
|
| 120 |
)[0]
|
|
|
|
|
|
|
|
|
|
| 121 |
return (
|
| 122 |
output_text,
|
| 123 |
voice_out,
|