Small updates

Files changed:
- kitt/core/__init__.py (+2 -2)
- main.py (+20 -14)

kitt/core/__init__.py
CHANGED
@@ -101,7 +101,7 @@ def speed_from_text(voice):
     return v.speed


-def
+def tts_xtts(
     self,
     text: str = "",
     language_name: str = "",
@@ -198,7 +198,7 @@ def tts_gradio(text, voice, cache):
     (gpt_cond_latent, speaker_embedding) = compute_speaker_embedding(
         voice_path, tts_pipeline.synthesizer.tts_config, tts_pipeline, cache
     )
-    out =
+    out = tts_xtts(
         tts_pipeline.synthesizer,
         text,
         language_name="en",
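The rename has to land in two places: the function definition and its call inside `tts_gradio`. Below is a minimal, stubbed sketch of the new call shape; the synthesizer and the synthesis step are fakes, and the extra XTTS inputs (e.g. `gpt_cond_latent` and `speaker_embedding` from the lines just above the call) are assumptions folded into `**kwargs`.

# Sketch only: the real tts_xtts lives in kitt/core/__init__.py and drives the
# XTTS synthesizer; here synthesis is stubbed so only the call shape shows.
def tts_xtts(self, text: str = "", language_name: str = "", **kwargs):
    # `self` is the synthesizer object, passed explicitly at the call site.
    print(f"{type(self).__name__}: {len(text)} chars, language={language_name!r}")
    return [0.0] * 16000  # placeholder waveform


class StubSynthesizer:
    """Stand-in for tts_pipeline.synthesizer."""


out = tts_xtts(StubSynthesizer(), "Hello from KITT", language_name="en")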
main.py
CHANGED
@@ -40,7 +40,7 @@ from kitt.skills.routing import calculate_route, find_address

 ORIGIN = "Mondorf-les-Bains, Luxembourg"
 DESTINATION = "Rue Alphonse Weicker, Luxembourg"
-DEFAULT_LLM_BACKEND = "
+DEFAULT_LLM_BACKEND = "replicate"
 ENABLE_HISTORY = True
 ENABLE_TTS = True
 TTS_BACKEND = "local"
@@ -133,11 +133,11 @@ def search_along_route(query=""):

 def set_time(time_picker):
     vehicle.time = time_picker
-    return vehicle.model_dump_json()
+    return vehicle.model_dump_json(indent=2)


 def get_vehicle_status(state):
-    return state.value["vehicle"].model_dump_json()
+    return state.value["vehicle"].model_dump_json(indent=2)


 tools = [
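The only change in these two helpers is pretty-printing: Pydantic v2's `model_dump_json(indent=2)` returns an indented JSON string rather than one compact line, which reads better in the JSON panel. A small self-contained illustration, using a made-up `Vehicle` model in place of the real one:

from pydantic import BaseModel


class Vehicle(BaseModel):
    # Illustrative fields only; the real Vehicle model lives in the kitt package.
    time: str = "08:00"
    destination: str = "Rue Alphonse Weicker, Luxembourg"


v = Vehicle()
print(v.model_dump_json())          # one compact line
print(v.model_dump_json(indent=2))  # same data, indented across several lines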
@@ -232,11 +232,16 @@ def run_llama3_model(query, voice_character, state):
     )
     gr.Info(f"Output text: {output_text}\nGenerating voice output...")
     voice_out = None
-    if
-
+    if global_context["tts_enabled"]:
+        if "Fast" in voice_character:
+            voice_out = run_melo_tts(output_text, voice_character)
+        elif global_context["tts_backend"] == "replicate":
+            voice_out = run_tts_replicate(output_text, voice_character)
+        else:
+            voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
+    #
     # voice_out = run_tts_fast(output_text)[0]
-
-    # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
+    #
     return (
         output_text,
         voice_out,
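The new branch order is: do nothing when TTS is disabled, use MeloTTS for the "Fast" voices, use the Replicate backend when configured, and otherwise fall back to local XTTS via `tts_gradio`. A runnable sketch of that dispatch with the three backends stubbed out (the `synthesize` wrapper name is only for the example):

# Stubbed stand-ins for the real kitt backends; each just reports which path ran.
def run_melo_tts(text, voice):
    return f"melo:{voice}:{text}"


def run_tts_replicate(text, voice):
    return f"replicate:{voice}:{text}"


def tts_gradio(text, voice, cache):
    return (f"xtts:{voice}:{text}", None)


global_context = {"tts_enabled": True, "tts_backend": "local"}
speaker_embedding_cache = {}


def synthesize(output_text, voice_character):
    # Mirrors the branch order added in run_llama3_model.
    voice_out = None
    if global_context["tts_enabled"]:
        if "Fast" in voice_character:
            voice_out = run_melo_tts(output_text, voice_character)
        elif global_context["tts_backend"] == "replicate":
            voice_out = run_tts_replicate(output_text, voice_character)
        else:
            voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
    return voice_out


print(synthesize("Turn left in 200 metres.", "Rick"))       # local XTTS branch
print(synthesize("Turn left in 200 metres.", "Fast Rick"))  # MeloTTS branch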
@@ -264,7 +269,7 @@ def run_model(query, voice_character, state):
     return (
         text,
         voice,
-        vehicle
+        vehicle,
         state,
         dict(update_proxy=global_context["update_proxy"]),
     )
@@ -299,7 +304,8 @@ def update_vehicle_status(trip_progress, origin, destination, state):
     plot = kitt_utils.plot_route(
         global_context["route_points"], vehicle=vehicle.location_coordinates
     )
-    return vehicle
+    return vehicle, plot, state
+    return vehicle.model_dump_json(indent=2), plot, state


 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -335,8 +341,8 @@ def save_and_transcribe_audio(audio):
         gr.Info(f"Transcribed text is: {text}\nProcessing the input...")

     except Exception as e:
-
-
+        logger.error(f"Error: {e}")
+        raise Exception("Error transcribing audio.")
     return text


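The except branch now records the failure and re-raises a user-facing error instead of falling through silently. A minimal sketch of the pattern with the transcription call stubbed out; the standard `logging` module stands in for the module's `logger`:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def fake_transcribe(audio):
    # Stand-in for the speech-to-text call in save_and_transcribe_audio.
    raise RuntimeError("no audio recorded")


def save_and_transcribe_audio(audio):
    try:
        text = fake_transcribe(audio)
    except Exception as e:
        logger.error(f"Error: {e}")
        raise Exception("Error transcribing audio.")
    return text


try:
    save_and_transcribe_audio(None)
except Exception as e:
    print(e)  # -> Error transcribing audio.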
@@ -447,6 +453,9 @@ def create_demo(tts_server: bool = False, model="llama3"):

         with gr.Row():
             with gr.Column(scale=1, min_width=300):
+                vehicle_status = gr.JSON(
+                    value=vehicle.model_dump_json(indent=2), label="Vehicle status"
+                )
                 time_picker = gr.Dropdown(
                     choices=hour_options,
                     label="What time is it? (HH:MM)",
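The vehicle status panel moves up into the left column and is seeded with the pretty-printed model dump. A standalone Gradio sketch of that corner of the layout, with a made-up `Vehicle` model and hour list in place of the real ones:

import gradio as gr
from pydantic import BaseModel


class Vehicle(BaseModel):
    # Illustrative stand-in for the real vehicle model.
    time: str = "08:00"
    location: str = "Mondorf-les-Bains, Luxembourg"


vehicle = Vehicle()
hour_options = [f"{h:02d}:00" for h in range(24)]

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            vehicle_status = gr.JSON(
                value=vehicle.model_dump_json(indent=2), label="Vehicle status"
            )
            time_picker = gr.Dropdown(
                choices=hour_options,
                label="What time is it? (HH:MM)",
            )

if __name__ == "__main__":
    demo.launch()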
@@ -516,9 +525,6 @@ def create_demo(tts_server: bool = False, model="llama3"):
                     value=dict(update_proxy=0),
                     label="Global context",
                 )
-                vehicle_status = gr.JSON(
-                    value=vehicle.model_dump_json(), label="Vehicle status"
-                )
                 with gr.Accordion("Config"):
                     tts_enabled = gr.Radio(
                         ["Yes", "No"],