Sort out vehicle status

- kitt/core/model.py +1 -1
- kitt/skills/routing.py +1 -1
- main.py +14 -18
kitt/core/model.py CHANGED

@@ -331,7 +331,7 @@ def run_inference_replicate(prompt):
     )
     out = "".join(output)
 
-    logger.debug(f"Response from
+    logger.debug(f"Response from Replicate:\nOut:{out}")
 
     return out
 
kitt/skills/routing.py CHANGED

@@ -59,7 +59,7 @@ def calculate_route(origin, destination):
     data = response.json()
     points = data["routes"][0]["legs"][0]["points"]
 
-    return vehicle
+    return vehicle, points
 
 
 def find_route_tomtom(
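Since calculate_route now returns a (vehicle, points) tuple rather than just the vehicle, callers must unpack both values; the calculate_route_gradio and update_vehicle_status changes in main.py below do exactly that. A minimal sketch of the calling pattern, assuming the TomTom-style point dicts used elsewhere in this repo (the coordinate strings are placeholders, not values from this commit):

    # Hypothetical call site -- unpack both values returned by calculate_route.
    vehicle, points = calculate_route("49.50,8.40", "49.60,8.50")
    # Each point is a dict with "latitude"/"longitude" keys, as used in main.py.
    start = points[0]["latitude"], points[0]["longitude"]
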
main.py CHANGED

@@ -133,11 +133,7 @@ def search_along_route(query=""):
 
 def set_time(time_picker):
     vehicle.time = time_picker
-    return vehicle
-
-
-def get_vehicle_status(state):
-    return state.value["vehicle"].model_dump_json(indent=2)
+    return vehicle
 
 
 tools = [

@@ -238,10 +234,12 @@ def run_llama3_model(query, voice_character, state):
     elif global_context["tts_backend"] == "replicate":
         voice_out = run_tts_replicate(output_text, voice_character)
     else:
-        voice_out = tts_gradio(
-
+        voice_out = tts_gradio(
+            output_text, voice_character, speaker_embedding_cache
+        )[0]
+    #
     # voice_out = run_tts_fast(output_text)[0]
-    #
+    #
     return (
         output_text,
         voice_out,

@@ -269,24 +267,24 @@ def run_model(query, voice_character, state):
     return (
         text,
         voice,
-        vehicle,
+        vehicle.model_dump(),
         state,
         dict(update_proxy=global_context["update_proxy"]),
     )
 
 
 def calculate_route_gradio(origin, destination):
-
+    _, points = calculate_route(origin, destination)
     plot = kitt_utils.plot_route(points, vehicle=vehicle.location_coordinates)
     global_context["route_points"] = points
     # state.value["route_points"] = points
     vehicle.location_coordinates = points[0]["latitude"], points[0]["longitude"]
-    return plot,
+    return plot, vehicle.model_dump(), 0
 
 
 def update_vehicle_status(trip_progress, origin, destination, state):
     if not global_context["route_points"]:
-
+        _, points = calculate_route(origin, destination)
         global_context["route_points"] = points
         global_context["destination"] = destination
         global_context["route_points"] = global_context["route_points"]

@@ -305,7 +303,6 @@ def update_vehicle_status(trip_progress, origin, destination, state):
         global_context["route_points"], vehicle=vehicle.location_coordinates
     )
     return vehicle, plot, state
-    return vehicle.model_dump_json(indent=2), plot, state
 
 
 device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -415,9 +412,7 @@ def conditional_update():
         or global_context["update_proxy"] == 0
     ):
         logger.info(f"Updating the map plot... in conditional_update")
-        map_plot,
-            vehicle.location, vehicle.destination
-        )
+        map_plot, _, _ = calculate_route_gradio(vehicle.location, vehicle.destination)
         global_context["map"] = map_plot
         return global_context["map"]
 

@@ -448,13 +443,13 @@ def create_demo(tts_server: bool = False, model="llama3"):
         }
     )
 
-    plot,
+    plot, _, _ = calculate_route_gradio(ORIGIN, DESTINATION)
     global_context["map"] = plot
 
     with gr.Row():
         with gr.Column(scale=1, min_width=300):
             vehicle_status = gr.JSON(
-                value=vehicle.
+                value=vehicle.model_dump(), label="Vehicle status"
             )
             time_picker = gr.Dropdown(
                 choices=hour_options,

@@ -649,6 +644,7 @@ demo.launch(
     ssl_verify=False,
     share=False,
 )
+
 app = typer.Typer()
 
 
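With these changes, calculate_route_gradio returns three values: the route plot, the vehicle state as a plain dict (vehicle.model_dump(), instead of the indented JSON string used before), and 0, presumably for the update_proxy value. A hedged sketch of how such a handler could be wired up inside create_demo; the textbox and button names are illustrative assumptions, only vehicle_status appears in this commit:

    # Hypothetical wiring inside the gr.Blocks layout in create_demo().
    origin_box = gr.Textbox(label="Origin")            # assumed component name
    destination_box = gr.Textbox(label="Destination")  # assumed component name
    route_button = gr.Button("Calculate route")        # assumed component name
    route_button.click(
        fn=calculate_route_gradio,
        inputs=[origin_box, destination_box],
        # Outputs line up with the (plot, vehicle.model_dump(), 0) return value.
        outputs=[map_plot, vehicle_status, update_proxy],
    )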