# GradLLM/runners/service.py
import asyncio
from typing import Dict, Optional
from collections import defaultdict
from .rabbit_repo import RabbitRepo
from .config import settings
from .runners.base import ILLMRunner
class LLMService:
    """Coordinates LLM runner sessions and publishes results via RabbitMQ.

    Sessions are keyed by ``"<RequestSessionId>_<LLMRunnerType>"`` and each
    holds a runner instance created by ``runner_factory``.  Behaviour mirrors
    the original .NET service (session lifecycle, system-LLM message
    suppression, prefix-based session removal).
    """

    def __init__(self, publisher: RabbitRepo, runner_factory):
        """Store the publisher and runner factory; mark the service ready.

        :param publisher: RabbitMQ repository exposing ``async publish(queue, obj)``.
        :param runner_factory: async callable ``(llmServiceObj) -> ILLMRunner``.
        """
        self._pub = publisher
        self._runner_factory = runner_factory
        # sessionId -> {"Runner": ILLMRunner, "FullSessionId": str}
        self._sessions: Dict[str, dict] = {}
        self._ready = asyncio.Event()
        self._ready.set()  # if you have async load, clear and set after

    async def init(self):
        """Async startup hook: load history here, then ``self._ready.set()``."""
        # If you have history to load, do here then self._ready.set()
        pass

    async def _set_result(self, obj: dict, message: str, success: bool,
                          queue: str, check_system: bool = False):
        """Stamp result fields onto *obj* and publish it to *queue*.

        Sets ``ResultMessage``/``ResultSuccess`` and wraps the message in
        ``<Success>``/``<Error>`` tags as ``LlmMessage``.  When *check_system*
        is True and the object is flagged ``IsSystemLlm``, publication is
        suppressed (mirrors the .NET rule: don't publish for the system LLM).
        """
        obj["ResultMessage"] = message
        obj["ResultSuccess"] = success
        obj["LlmMessage"] = (f"<Success>{message}</Success>" if success
                             else f"<Error>{message}</Error>")
        if not (check_system and obj.get("IsSystemLlm")):
            await self._pub.publish(queue, obj)

    async def StartProcess(self, llmServiceObj: dict):
        """Start (or reuse) a runner for the request's session.

        Waits up to ~120s for initialization (like the .NET side), replaces
        failed runners, and announces startup on ``llmServiceStarted``.
        """
        session_id = f"{llmServiceObj['RequestSessionId']}_{llmServiceObj['LLMRunnerType']}"
        llmServiceObj["SessionId"] = session_id
        # wait ready (max ~120s like .NET)
        try:
            await asyncio.wait_for(self._ready.wait(), timeout=120)
        except asyncio.TimeoutError:
            await self._set_result(llmServiceObj, "Timed out waiting for initialization.",
                                   False, "llmServiceMessage", True)
            return
        sess = self._sessions.get(session_id)
        is_runner_null = not sess or not sess.get("Runner")
        # Recreate the runner when missing or when its state reports failure.
        create_new = is_runner_null or sess["Runner"].IsStateFailed
        if create_new:
            if sess and sess.get("Runner"):
                # Best-effort cleanup of the stale runner; never let cleanup
                # errors block the restart.  (Was a bare ``except:`` which also
                # swallowed KeyboardInterrupt/SystemExit.)
                try:
                    await sess["Runner"].RemoveProcess(session_id)
                except Exception:
                    pass
            runner: ILLMRunner = await self._runner_factory(llmServiceObj)
            if not runner.IsEnabled:
                await self._set_result(
                    llmServiceObj,
                    f"{llmServiceObj['LLMRunnerType']} {settings.SERVICE_ID} not started as it is disabled.",
                    True, "llmServiceMessage")
                return
            await self._set_result(llmServiceObj,
                                   f"Starting {runner.Type} {settings.SERVICE_ID} Expert",
                                   True, "llmServiceMessage", True)
            await runner.StartProcess(llmServiceObj)
            self._sessions[session_id] = {"Runner": runner, "FullSessionId": session_id}
            if settings.SERVICE_ID == "monitor":
                await self._set_result(
                    llmServiceObj,
                    f"Hi i'm {runner.Type} your Network Monitor Assistant. How can I help you.",
                    True, "llmServiceMessage", True)
        await self._pub.publish("llmServiceStarted", llmServiceObj)

    async def RemoveSession(self, llmServiceObj: dict):
        """Remove every session sharing the request's base session id.

        Behaves like the .NET ``RemoveAllSessionIdProcesses`` (prefix match on
        ``"<base>_"``).  Publishes one aggregate success or failure message.
        """
        base = llmServiceObj.get("SessionId", "").split("_")[0]
        targets = [k for k in self._sessions.keys() if k.startswith(base + "_")]
        msgs = []
        ok = True
        for sid in targets:
            s = self._sessions.get(sid)
            if s and s.get("Runner"):
                try:
                    await s["Runner"].RemoveProcess(sid)
                    s["Runner"] = None  # keep the entry, drop the runner
                    msgs.append(sid)
                except Exception as e:
                    ok = False
                    msgs.append(f"Error {sid}: {e}")
        if ok:
            await self._set_result(llmServiceObj,
                                   f"Success: Removed sessions for {' '.join(msgs)}",
                                   True, "llmSessionMessage", True)
        else:
            await self._set_result(llmServiceObj, " ".join(msgs), False, "llmServiceMessage")

    async def StopRequest(self, llmServiceObj: dict):
        """Halt the running assistant output for the request's session."""
        sid = llmServiceObj.get("SessionId", "")
        s = self._sessions.get(sid)
        if not s or not s.get("Runner"):
            await self._set_result(llmServiceObj,
                                   f"Error: Runner missing for session {sid}.",
                                   False, "llmServiceMessage")
            return
        await s["Runner"].StopRequest(sid)
        await self._set_result(
            llmServiceObj,
            f"Success {s['Runner'].Type} {settings.SERVICE_ID} Assistant output has been halted",
            True, "llmServiceMessage", True)

    async def UserInput(self, llmServiceObj: dict):
        """Forward user input to the session's runner, guarding runner state."""
        sid = llmServiceObj.get("SessionId", "")
        s = self._sessions.get(sid)
        if not s or not s.get("Runner"):
            await self._set_result(llmServiceObj,
                                   f"Error: SessionId {sid} has no running process.",
                                   False, "llmServiceMessage")
            return
        r: ILLMRunner = s["Runner"]
        if r.IsStateStarting:
            await self._set_result(llmServiceObj,
                                   "Please wait, the assistant is starting...",
                                   False, "llmServiceMessage")
            return
        if r.IsStateFailed:
            await self._set_result(llmServiceObj,
                                   "The Assistant is stopped. Try reloading.",
                                   False, "llmServiceMessage")
            return
        await r.SendInputAndGetResponse(llmServiceObj)
        # emitter side can push partials directly to queues if desired

    async def QueryIndexResult(self, queryIndexRequest: dict):
        """Concatenate query outputs and publish them as a service message.

        Adapted from the .NET behaviour (``_queryCoordinator.CompleteQuery``);
        here the RAG data is forwarded on ``llmServiceMessage`` instead.
        Any failure is reported back on the same queue rather than raised.
        """
        try:
            rag_data = "\n".join(
                [qr.get("Output", "") for qr in (queryIndexRequest.get("QueryResults") or [])])
            await self._pub.publish("llmServiceMessage", {
                "ResultSuccess": queryIndexRequest.get("Success", False),
                "ResultMessage": queryIndexRequest.get("Message", ""),
                "Data": rag_data,
            })
        except Exception as e:
            await self._pub.publish("llmServiceMessage",
                                    {"ResultSuccess": False, "ResultMessage": str(e)})

    async def GetFunctionRegistry(self, filtered: bool = False):
        """Publish the (stub) function catalog; plug in your real registry here."""
        data = {"FunctionCatalogJson": "{}", "Filtered": filtered}
        await self._pub.publish("llmServiceMessage", {
            "ResultSuccess": True,
            "ResultMessage": f"Success : Got GetFunctionCatalogJson : {data}"})