johnbridges committed on
Commit
7630510
·
1 Parent(s): ff11fee
Files changed (3) hide show
  1. app.py +41 -21
  2. listener.py +18 -6
  3. rabbit_base.py +105 -31
app.py CHANGED
@@ -2,6 +2,7 @@
2
  import asyncio
3
  import gradio as gr
4
  from fastapi import FastAPI
 
5
 
6
  from config import settings
7
  from rabbit_base import RabbitBase
@@ -10,6 +11,16 @@ from rabbit_repo import RabbitRepo
10
  from service import LLMService
11
  from runners.base import ILLMRunner
12
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # --- Runner factory (stub) ---
15
  class EchoRunner(ILLMRunner):
@@ -46,11 +57,11 @@ handlers = {
46
  "getFunctionRegistryFiltered": h_getreg_f,
47
  }
48
 
49
- # --- Listener wiring (fix: needs base + instance_name) ---
50
  base = RabbitBase()
51
  listener = RabbitListenerBase(
52
  base,
53
- instance_name=settings.RABBIT_INSTANCE_NAME, # <- queue prefix, like your .NET instance
54
  handlers=handlers,
55
  )
56
 
@@ -61,17 +72,27 @@ DECLS = [
61
  {"ExchangeName": f"llmUserInput{settings.SERVICE_ID}", "FuncName": "llmUserInput",
62
  "MessageTimeout": 600_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
63
  {"ExchangeName": f"llmRemoveSession{settings.SERVICE_ID}", "FuncName": "llmRemoveSession",
64
- "MessageTimeout": 60_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
65
  {"ExchangeName": f"llmStopRequest{settings.SERVICE_ID}", "FuncName": "llmStopRequest",
66
- "MessageTimeout": 60_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
67
  {"ExchangeName": f"queryIndexResult{settings.SERVICE_ID}", "FuncName": "queryIndexResult",
68
- "MessageTimeout": 60_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
69
  {"ExchangeName": f"getFunctionRegistry{settings.SERVICE_ID}", "FuncName": "getFunctionRegistry",
70
- "MessageTimeout": 60_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
71
  {"ExchangeName": f"getFunctionRegistryFiltered{settings.SERVICE_ID}", "FuncName": "getFunctionRegistryFiltered",
72
- "MessageTimeout": 60_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
73
  ]
74
 
 
 
 
 
 
 
 
 
 
 
75
  # --- Gradio UI (for smoke test) ---
76
  async def ping():
77
  return "ok"
@@ -82,25 +103,24 @@ with gr.Blocks() as demo:
82
  out = gr.Textbox()
83
  btn.click(ping, inputs=None, outputs=out)
84
 
85
- # --- FastAPI mount + lifecycle ---
86
- app = FastAPI()
87
- app = gr.mount_gradio_app(app, demo, path="/")
88
-
89
- @app.get("/health")
90
- async def health():
91
- return {"status": "ok"}
92
-
93
- @app.on_event("startup")
94
- async def on_start():
95
  await publisher.connect()
96
  await service.init()
97
  await listener.start(DECLS)
 
 
 
 
98
 
99
- @app.on_event("shutdown")
100
- async def on_stop():
101
- # Optionally close connections/channels if you track them
102
- pass
103
 
 
 
 
104
 
105
  if __name__ == "__main__":
106
  import uvicorn
 
2
  import asyncio
3
  import gradio as gr
4
  from fastapi import FastAPI
5
+ from contextlib import asynccontextmanager
6
 
7
  from config import settings
8
  from rabbit_base import RabbitBase
 
11
  from service import LLMService
12
  from runners.base import ILLMRunner
13
 
14
+ # --- Optional ZeroGPU hook ---
15
+ # If your Space uses ZeroGPU hardware, this satisfies the startup check.
16
+ # If you're on CPU hardware, this is harmless.
17
+ try:
18
+ import spaces
19
+ ZERO_GPU_AVAILABLE = True
20
+ except Exception:
21
+ spaces = None
22
+ ZERO_GPU_AVAILABLE = False
23
+
24
 
25
  # --- Runner factory (stub) ---
26
  class EchoRunner(ILLMRunner):
 
57
  "getFunctionRegistryFiltered": h_getreg_f,
58
  }
59
 
60
+ # --- Listener wiring (needs base + instance_name) ---
61
  base = RabbitBase()
62
  listener = RabbitListenerBase(
63
  base,
64
+ instance_name=settings.RABBIT_INSTANCE_NAME, # queue prefix like your .NET instance
65
  handlers=handlers,
66
  )
67
 
 
72
  {"ExchangeName": f"llmUserInput{settings.SERVICE_ID}", "FuncName": "llmUserInput",
73
  "MessageTimeout": 600_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
74
  {"ExchangeName": f"llmRemoveSession{settings.SERVICE_ID}", "FuncName": "llmRemoveSession",
75
+ "MessageTimeout": 60_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
76
  {"ExchangeName": f"llmStopRequest{settings.SERVICE_ID}", "FuncName": "llmStopRequest",
77
+ "MessageTimeout": 60_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
78
  {"ExchangeName": f"queryIndexResult{settings.SERVICE_ID}", "FuncName": "queryIndexResult",
79
+ "MessageTimeout": 60_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
80
  {"ExchangeName": f"getFunctionRegistry{settings.SERVICE_ID}", "FuncName": "getFunctionRegistry",
81
+ "MessageTimeout": 60_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
82
  {"ExchangeName": f"getFunctionRegistryFiltered{settings.SERVICE_ID}", "FuncName": "getFunctionRegistryFiltered",
83
+ "MessageTimeout": 60_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
84
  ]
85
 
86
+ # --- ZeroGPU detection function (no-op) ---
87
+ # This only exists so HF Spaces sees that you "have" a GPU entrypoint on ZeroGPU.
88
+ if ZERO_GPU_AVAILABLE:
89
+ @spaces.GPU() # duration can be omitted; we don't invoke it at startup
90
+ def gpu_ready_probe() -> str:
91
+ # Do not allocate any large tensors; just a trivial statement.
92
+ # Presence of this function is enough for the ZeroGPU startup check.
93
+ return "gpu-probe-ok"
94
+
95
+
96
  # --- Gradio UI (for smoke test) ---
97
  async def ping():
98
  return "ok"
 
103
  out = gr.Textbox()
104
  btn.click(ping, inputs=None, outputs=out)
105
 
106
+ # --- FastAPI app with lifespan (replaces deprecated @on_event) ---
107
+ @asynccontextmanager
108
+ async def lifespan(_app: FastAPI):
109
+ # startup
 
 
 
 
 
 
110
  await publisher.connect()
111
  await service.init()
112
  await listener.start(DECLS)
113
+ yield
114
+ # shutdown (optional cleanup)
115
+ # await publisher.close() # if your RabbitRepo exposes this
116
+ # await listener.stop() # if you implement stop()
117
 
118
+ app = FastAPI(lifespan=lifespan)
119
+ app = gr.mount_gradio_app(app, demo, path="/")
 
 
120
 
121
+ @app.get("/health")
122
+ async def health():
123
+ return {"status": "ok"}
124
 
125
  if __name__ == "__main__":
126
  import uvicorn
listener.py CHANGED
@@ -1,11 +1,18 @@
1
- # listener.py
2
  import json
 
3
  from typing import Callable, Awaitable, Dict, Any, List
4
 
5
  import aio_pika
6
 
7
  Handler = Callable[[Any], Awaitable[None]] # payload is envelope["data"]
8
 
 
 
 
 
 
 
 
9
 
10
  class RabbitListenerBase:
11
  def __init__(self, base, instance_name: str, handlers: Dict[str, Handler]):
@@ -37,13 +44,18 @@ class RabbitListenerBase:
37
  async def _on_msg(msg: aio_pika.IncomingMessage):
38
  async with msg.process():
39
  try:
40
- envelope = json.loads(msg.body.decode("utf-8"))
41
- # Expect CloudEvent-ish envelope; we only need the 'data' field
 
 
 
 
 
 
42
  data = envelope.get("data", None)
43
  if handler:
44
  await handler(data)
45
- except Exception:
46
- # Avoid requeue storms; add logging if you want
47
- pass
48
 
49
  return _on_msg
 
 
1
  import json
2
+ import logging
3
  from typing import Callable, Awaitable, Dict, Any, List
4
 
5
  import aio_pika
6
 
7
  Handler = Callable[[Any], Awaitable[None]] # payload is envelope["data"]
8
 
9
+ # Configure root logger if not already configured
10
+ logging.basicConfig(
11
+ level=logging.INFO,
12
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
13
+ )
14
+ logger = logging.getLogger(__name__)
15
+
16
 
17
  class RabbitListenerBase:
18
  def __init__(self, base, instance_name: str, handlers: Dict[str, Handler]):
 
44
  async def _on_msg(msg: aio_pika.IncomingMessage):
45
  async with msg.process():
46
  try:
47
+ raw_body = msg.body.decode("utf-8", errors="replace")
48
+ logger.info(
49
+ "Received message for handler '%s': %s",
50
+ func_name,
51
+ raw_body
52
+ )
53
+
54
+ envelope = json.loads(raw_body)
55
  data = envelope.get("data", None)
56
  if handler:
57
  await handler(data)
58
+ except Exception as e:
59
+ logger.exception("Error processing message for '%s'", func_name)
 
60
 
61
  return _on_msg
rabbit_base.py CHANGED
@@ -1,37 +1,44 @@
 
1
  from typing import Callable, Dict, List, Optional
2
- import aio_pika
3
  from urllib.parse import urlsplit, unquote
4
- from config import settings
5
  import ssl
 
 
 
 
 
6
 
7
  ExchangeResolver = Callable[[str], str] # exchangeName -> exchangeType
8
 
9
- # rabbit_base.py
10
- import aio_pika
 
 
 
 
 
11
 
12
  def _normalize_exchange_type(val: str) -> aio_pika.ExchangeType:
13
- # 1) Try attribute by NAME (DIRECT/FANOUT/TOPIC/HEADERS)
14
  if isinstance(val, str):
15
  name = val.upper()
16
  if hasattr(aio_pika.ExchangeType, name):
17
  return getattr(aio_pika.ExchangeType, name)
18
- # 2) Try enum by VALUE ("direct"/"fanout"/"topic"/"headers")
19
  try:
20
  return aio_pika.ExchangeType(val.lower())
21
  except Exception:
22
  pass
23
- # 3) Default
24
  return aio_pika.ExchangeType.TOPIC
25
 
26
  def _parse_amqp_url(url: str) -> dict:
27
  parts = urlsplit(url)
28
  return {
 
29
  "host": parts.hostname or "localhost",
30
  "port": parts.port or (5671 if parts.scheme == "amqps" else 5672),
31
  "login": parts.username or "guest",
32
  "password": parts.password or "guest",
33
  "virtualhost": unquote(parts.path[1:] or "/"),
34
- "ssl": parts.scheme == "amqps",
35
  }
36
 
37
 
@@ -44,37 +51,90 @@ class RabbitBase:
44
  lambda _: settings.RABBIT_EXCHANGE_TYPE
45
  )
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  async def connect(self) -> None:
48
  if self._conn and not self._conn.is_closed:
49
- return
 
50
 
51
  conn_kwargs = _parse_amqp_url(str(settings.AMQP_URL))
52
 
53
- # Build an SSLContext that DISABLES verification
 
 
 
 
 
 
 
 
 
 
 
54
  ssl_ctx = None
55
  if conn_kwargs.get("ssl"):
56
  ssl_ctx = ssl.create_default_context()
57
  ssl_ctx.check_hostname = False
58
  ssl_ctx.verify_mode = ssl.CERT_NONE
 
59
 
60
- # Pass ssl_context explicitly – this is what aio-pika supports
61
- self._conn = await aio_pika.connect_robust(
62
- **conn_kwargs,
63
- ssl_context=ssl_ctx # <- key bit
64
- )
65
- self._chan = await self._conn.channel()
66
- await self._chan.set_qos(prefetch_count=settings.RABBIT_PREFETCH)
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  async def ensure_exchange(self, name: str) -> aio_pika.Exchange:
69
  await self.connect()
70
  if name in self._exchanges:
71
  return self._exchanges[name]
 
72
  ex_type_str = self._exchange_type_resolver(name) # e.g. "direct"
73
  ex_type = _normalize_exchange_type(ex_type_str)
74
- ex = await self._chan.declare_exchange(name, ex_type, durable=True)
75
-
76
- self._exchanges[name] = ex
77
- return ex
 
 
 
 
 
78
 
79
  async def declare_queue_bind(
80
  self,
@@ -85,16 +145,30 @@ class RabbitBase:
85
  ):
86
  await self.connect()
87
  ex = await self.ensure_exchange(exchange)
 
88
  args: Dict[str, int] = {}
89
  if ttl_ms:
90
  args["x-message-ttl"] = ttl_ms
91
- q = await self._chan.declare_queue(
92
- queue_name,
93
- durable=True,
94
- exclusive=False,
95
- auto_delete=True,
96
- arguments=args,
97
- )
98
- for rk in routing_keys or [""]:
99
- await q.bind(ex, rk)
100
- return q
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rabbit_base.py
2
  from typing import Callable, Dict, List, Optional
 
3
  from urllib.parse import urlsplit, unquote
 
4
  import ssl
5
+ import json
6
+ import logging
7
+ import aio_pika
8
+
9
+ from config import settings
10
 
11
  ExchangeResolver = Callable[[str], str] # exchangeName -> exchangeType
12
 
13
+ # --- logging setup (keep simple; inherit root handlers if already configured) ---
14
+ logger = logging.getLogger(__name__)
15
+ if not logger.handlers:
16
+ logging.basicConfig(
17
+ level=logging.INFO,
18
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
19
+ )
20
 
21
  def _normalize_exchange_type(val: str) -> aio_pika.ExchangeType:
 
22
  if isinstance(val, str):
23
  name = val.upper()
24
  if hasattr(aio_pika.ExchangeType, name):
25
  return getattr(aio_pika.ExchangeType, name)
 
26
  try:
27
  return aio_pika.ExchangeType(val.lower())
28
  except Exception:
29
  pass
 
30
  return aio_pika.ExchangeType.TOPIC
31
 
32
  def _parse_amqp_url(url: str) -> dict:
33
  parts = urlsplit(url)
34
  return {
35
+ "scheme": parts.scheme or "amqp",
36
  "host": parts.hostname or "localhost",
37
  "port": parts.port or (5671 if parts.scheme == "amqps" else 5672),
38
  "login": parts.username or "guest",
39
  "password": parts.password or "guest",
40
  "virtualhost": unquote(parts.path[1:] or "/"),
41
+ "ssl": (parts.scheme == "amqps"),
42
  }
43
 
44
 
 
51
  lambda _: settings.RABBIT_EXCHANGE_TYPE
52
  )
53
 
54
+ # -------- Status helpers --------
55
+ def is_connected(self) -> bool:
56
+ return bool(
57
+ self._conn and not self._conn.is_closed and
58
+ self._chan and not self._chan.is_closed
59
+ )
60
+
61
+ async def close(self) -> None:
62
+ try:
63
+ if self._chan and not self._chan.is_closed:
64
+ logger.info("Closing AMQP channel")
65
+ await self._chan.close()
66
+ finally:
67
+ self._chan = None
68
+ try:
69
+ if self._conn and not self._conn.is_closed:
70
+ logger.info("Closing AMQP connection")
71
+ await self._conn.close()
72
+ finally:
73
+ self._conn = None
74
+ logger.info("AMQP connection closed")
75
+
76
+ # -------- Core ops --------
77
  async def connect(self) -> None:
78
  if self._conn and not self._conn.is_closed:
79
+ if self._chan and not self._chan.is_closed:
80
+ return
81
 
82
  conn_kwargs = _parse_amqp_url(str(settings.AMQP_URL))
83
 
84
+ # Log connection target (the password is deliberately left out of the logged dict)
85
+ safe_target = {
86
+ "scheme": conn_kwargs["scheme"],
87
+ "host": conn_kwargs["host"],
88
+ "port": conn_kwargs["port"],
89
+ "virtualhost": conn_kwargs["virtualhost"],
90
+ "ssl": conn_kwargs["ssl"],
91
+ "login": conn_kwargs["login"],
92
+ }
93
+ logger.info("AMQP connect -> %s", json.dumps(safe_target))
94
+
95
+ # TLS: certificate and hostname verification are intentionally DISABLED for every amqps:// URL (not opt-in)
96
  ssl_ctx = None
97
  if conn_kwargs.get("ssl"):
98
  ssl_ctx = ssl.create_default_context()
99
  ssl_ctx.check_hostname = False
100
  ssl_ctx.verify_mode = ssl.CERT_NONE
101
+ logger.warning("AMQP TLS verification is DISABLED (CERT_NONE)")
102
 
103
+ try:
104
+ self._conn = await aio_pika.connect_robust(
105
+ host=conn_kwargs["host"],
106
+ port=conn_kwargs["port"],
107
+ login=conn_kwargs["login"],
108
+ password=conn_kwargs["password"],
109
+ virtualhost=conn_kwargs["virtualhost"],
110
+ ssl=conn_kwargs["ssl"],
111
+ ssl_context=ssl_ctx,
112
+ )
113
+ logger.info("AMQP connection established")
114
+ self._chan = await self._conn.channel()
115
+ logger.info("AMQP channel created")
116
+ await self._chan.set_qos(prefetch_count=settings.RABBIT_PREFETCH)
117
+ logger.info("AMQP QoS set (prefetch=%s)", settings.RABBIT_PREFETCH)
118
+ except Exception:
119
+ logger.exception("AMQP connection/channel setup failed")
120
+ raise
121
 
122
  async def ensure_exchange(self, name: str) -> aio_pika.Exchange:
123
  await self.connect()
124
  if name in self._exchanges:
125
  return self._exchanges[name]
126
+
127
  ex_type_str = self._exchange_type_resolver(name) # e.g. "direct"
128
  ex_type = _normalize_exchange_type(ex_type_str)
129
+
130
+ try:
131
+ ex = await self._chan.declare_exchange(name, ex_type, durable=True)
132
+ self._exchanges[name] = ex
133
+ logger.info("Exchange declared: name=%s type=%s durable=true", name, ex_type.value)
134
+ return ex
135
+ except Exception:
136
+ logger.exception("Failed declaring exchange: %s (%s)", name, ex_type_str)
137
+ raise
138
 
139
  async def declare_queue_bind(
140
  self,
 
145
  ):
146
  await self.connect()
147
  ex = await self.ensure_exchange(exchange)
148
+
149
  args: Dict[str, int] = {}
150
  if ttl_ms:
151
  args["x-message-ttl"] = ttl_ms
152
+
153
+ try:
154
+ q = await self._chan.declare_queue(
155
+ queue_name,
156
+ durable=True,
157
+ exclusive=False,
158
+ auto_delete=True,
159
+ arguments=args,
160
+ )
161
+ logger.info(
162
+ "Queue declared: name=%s durable=true auto_delete=true args=%s",
163
+ queue_name, args or {}
164
+ )
165
+ for rk in routing_keys or [""]:
166
+ await q.bind(ex, rk)
167
+ logger.info("Queue bound: queue=%s exchange=%s rk='%s'", queue_name, exchange, rk)
168
+ return q
169
+ except Exception:
170
+ logger.exception(
171
+ "Failed declare/bind queue: queue=%s exchange=%s rks=%s args=%s",
172
+ queue_name, exchange, routing_keys, args or {}
173
+ )
174
+ raise