|
|
from kani_utils.base_kanis import StreamlitKani |
|
|
from kani import AIParam, ai_function, ChatMessage, ToolCall |
|
|
import streamlit as st |
|
|
|
|
|
from typing import Annotated |
|
|
import pandas as pd |
|
|
import re |
|
|
import yaml |
|
|
|
|
|
from icdtree import ICDTree, get_tree |
|
|
|
|
|
|
|
|
class ICDBot(StreamlitKani):
    """Streamlit chat agent with tool access to the ICD-10-CM code hierarchy.

    Tools exposed to the model: ``semantic_search`` (embedding search over
    codes), ``get_context`` (node, parent, siblings and children of a code)
    and ``think`` (a no-op scratchpad).
    """

    def __init__(self, *args, **kwargs):
        """Set the ICDBot system prompt and UI metadata, then load the ICD-10 tree.

        Any ``system_prompt`` supplied by the caller is overwritten; all other
        arguments are forwarded unchanged to ``StreamlitKani``.
        """
        kwargs['system_prompt'] = """
You are ICDBot, a chatbot with programmatic access to the ICD-10-CM vocabulary. You can search for ICD-10 codes via the `semantic_search` function and query for terms' parent and children nodes in the hierarchy.

* Use the `get_context` function to find the most specific code relevant to the user's query by navigating up and down the term hierarchy. For example, if the search for "fracture humerus" returns S42.2, you might use `get_context` on S42.2 to see its parent code and its children codes S42.20, S42.21, ..., to identify the most specific code that matches the user's query.
""".strip()

        super().__init__(*args, **kwargs)

        # Chat UI presentation.
        self.avatar = "🤖"
        self.user_avatar = "👤"
        self.name = "ICDBot"
        self.greeting = "Hello, I'm ICDBot. Nice to meet you!"
        self.description = "An agent with access to the ICD-10-CM vocabulary."

        # Code hierarchy queried by the tools below; loaded once per bot instance.
        self.icd10_tree = get_tree()

    @staticmethod
    def _format_node(node):
        """Render a tree node as ``"<code>: <short_desc>; Valid: <valid>"`` for tool output."""
        return f"{node.code}: {node.short_desc}; Valid: {node.valid}"

    # NOTE: docstrings and AIParam descriptions of @ai_function methods are sent to
    # the model as tool descriptions, so they are treated as runtime text.
    @ai_function()
    def semantic_search(self,
                        query: Annotated[str, AIParam(desc="Query string to search for.")],
                        limit: Annotated[int, AIParam(desc="Maximum number of results to return. Default to 10 unless necessary.")] = 10,
                        offset: Annotated[int, AIParam(desc="Number of results to skip.")] = 0):
        """Search for ICD-10 codes similar to the query using an embedding-search strategy."""
        search_results = self.icd10_tree.semantic_search(query, limit=limit, offset=offset)
        # Results are pairs; the second element (presumably a similarity score) is
        # not shown to the model.
        matches = [self._format_node(node) for node, _ in search_results]
        return yaml.dump({"Matches": matches})

    @ai_function()
    def think(self, thought: Annotated[str, AIParam(desc="The thought to think.")]):
        """Use this function to record internal thoughts or notes, and for planning complex actions. You may call this function more than once in succession for complex plans."""
        # Intentionally a no-op: the tool exists so the model can write out its
        # reasoning as call arguments.
        return None

    @ai_function()
    def get_context(self,
                    code: Annotated[str, AIParam(desc="ICD-10 code to retrieve context for")]):
        """Given an ICD-10 code, return information about its parent, children, and siblings."""
        node = self.icd10_tree.get_node_by_code(code)
        if node is None:
            # Tolerate stray whitespace / lower-case input from the model before
            # giving up (assumes the tree stores codes in upper case).
            normalized = code.strip().upper()
            if normalized != code:
                node = self.icd10_tree.get_node_by_code(normalized)
        if node is None:
            return f"Code {code} not found."

        parent = node.parent
        if parent is not None:
            parent_string = self._format_node(parent)
            # The node itself is reported under "Node", so leave it out of its siblings.
            siblings = [s for s in parent.children if s.code != node.code]
        else:
            parent_string = "No parent"
            siblings = []

        context = {
            "Node": self._format_node(node),
            "Parent": parent_string,
            "Siblings": [self._format_node(s) for s in siblings],
            "Children": [self._format_node(c) for c in node.children],
        }
        # sort_keys=False keeps the reading order above; yaml.dump sorts keys by default.
        return yaml.dump(context, sort_keys=False)

    def _build_fewshot(self):
        """Build a worked example conversation demonstrating search-then-navigate tool use.

        The example runs ``semantic_search`` for rheumatoid arthritis with
        polyneuropathy, then walks M05.5 -> M05.53 -> M05.539 with
        ``get_context``. Every tool result is produced by actually calling the
        tool on the loaded tree and is checked for the code the next step
        relies on, so the example fails loudly instead of drifting from the data.

        Returns:
            list of ChatMessage: the few-shot history, ending with the assistant's answer.

        Raises:
            RuntimeError: if a tool result does not contain the expected text.
        """
        history = [ChatMessage.user("Is there a condition that corresponds to rheumatoid arthritis in a patient's wrist, when they also have neuropathy?")]

        def run_tool(func_name, expected, **kwargs):
            # Record the assistant's tool call, execute the same tool locally, and
            # record its result under the matching tool-call id.
            call = ToolCall.from_function(func_name, **kwargs)
            history.append(ChatMessage.assistant(content=None, tool_calls=[call]))
            result = getattr(self, func_name)(**kwargs)
            if expected not in result:
                raise RuntimeError(
                    f"Few-shot example is out of date: {func_name}({kwargs}) "
                    f"did not return {expected!r}"
                )
            history.append(ChatMessage.function(func_name, result, call.id))

        run_tool("semantic_search", "M05.5", query="rheumatoid arthritis with polyneuropathy")
        run_tool("get_context", "M05.53", code="M05.5")
        run_tool("get_context", "M05.539", code="M05.53")
        # An empty children list confirms M05.539 is the most specific code.
        run_tool("get_context", "Children: []", code="M05.539")

        history.append(ChatMessage.assistant("Yes, there are multiple codes for rheumatoid arthritis in the wrist with polyneuropathy. Since you didn't mention left or right, M05.539 'Rheumatoid polyneurop w rheumatoid arthritis of unsp wrist' appears to be the most appropriate."))
        return history