File size: 6,800 Bytes
c0aa211 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from kani_utils.base_kanis import StreamlitKani
from kani import AIParam, ai_function, ChatMessage, ToolCall
import streamlit as st
from typing import Annotated
import pandas as pd
import re
import yaml
from icdtree import ICDTree, get_tree
class ICDBot(StreamlitKani):
    """Chat agent with tools for searching and navigating the ICD-10-CM vocabulary.

    The agent exposes three AI-callable functions: ``semantic_search`` (find
    candidate codes for a query), ``get_context`` (inspect a code's parent,
    siblings and children) and ``think`` (a scratchpad for planning).
    """

    # Be sure to override the __init__ method to pass any parameters to the superclass
    def __init__(self, *args, **kwargs):
        """Set the system prompt, UI metadata, and load the ICD-10 tree.

        All arguments are forwarded to ``StreamlitKani``. Any ``system_prompt``
        supplied by the caller is replaced with ICDBot's own prompt.
        """
        # NOTE: few-shot seeding is currently disabled. To enable it, pass
        # `chat_history=self._build_fewshot()` here -- but `self.icd10_tree`
        # must be assigned first, because the few-shot builder runs the tools.
        kwargs['system_prompt'] = (
            "You are ICDBot, a chatbot with programmatic access to the ICD-10-CM vocabulary. "
            "You can search for ICD-10 codes via the `semantic_search` function and query for "
            "terms' parent and children nodes in the hierarchy.\n"
            "* Use the `get_context` function to find the most specific code relevant to the "
            "user's query by navigating up and down the term hierarchy. For example, if the "
            "search for \"fracture humerus\" returns S42.2, you might use `get_context` to find "
            "the parent code S42.2 and its children codes S42.20, S42.21, ..., to identify the "
            "most specific code that matches the user's query."
        )

        super().__init__(*args, **kwargs)

        # Define avatars for the agent and user
        # Can be URLs or emojis
        self.avatar = "🤖"
        self.user_avatar = "👤"

        # The name and greeting are shown at the start of the chat
        # The greeting is not known to the LLM, it serves as a prompt for the user
        self.name = "ICDBot"
        self.greeting = "Hello, I'm ICDBot. Nice to meet you!"

        # The description is shown in the sidebar and provides more information about the agent
        self.description = "An agent with access to the ICD-10-CM vocabulary."

        self.icd10_tree = get_tree()

    @staticmethod
    def _describe(node) -> str:
        """Format a tree node as ``"<code>: <short description>; Valid: <valid flag>"``."""
        return f"{node.code}: {node.short_desc}; Valid: {node.valid}"

    @ai_function()
    def semantic_search(self,
                        query: Annotated[str, AIParam(desc="Query string to search for.")],
                        limit: Annotated[int, AIParam(desc="Maximum number of results to return. Default to 10 unless necessary.")]=10,
                        offset: Annotated[int, AIParam(desc="Number of results to skip.")]=0) -> str:
        """Search for ICD-10 codes similar to the query using an embedding-search strategy."""
        hits = self.icd10_tree.semantic_search(query, limit=limit, offset=offset)
        # Each hit is a pair; only its first element (the node) is reported.
        return yaml.dump({"Matches": [self._describe(node) for node, _ in hits]})

    @ai_function()
    def think(self, thought: Annotated[str, AIParam(desc="The thought to think.")]) -> None:
        """Use this function to record internal thoughts or notes, and for planning complex actions. You may call this function more than once in succession for complex plans."""
        # Intentionally a no-op: the thought is only recorded in the tool call itself.
        return None

    @ai_function()
    def get_context(self,
                    code: Annotated[str, AIParam(desc="ICD-10 code to retrieve context for")]) -> str:
        """Given an ICD-10 code, return information about its parent, children, and siblings."""
        node = self.icd10_tree.get_node_by_code(code)
        if not node:
            return f"Code {code} not found."

        parent = node.parent
        # Siblings are all children of the parent, so the list includes the node itself.
        siblings = parent.children if parent else []

        # format the result as yaml
        return yaml.dump({
            "Node": self._describe(node),
            "Parent": self._describe(parent) if parent else "No parent",
            "Siblings": [self._describe(sibling) for sibling in siblings],
            "Children": [self._describe(child) for child in node.children],
        })

    def _build_fewshot(self) -> list:
        """Build a worked example conversation suitable for seeding the chat history.

        The example replays real calls to this agent's own tools (one
        ``semantic_search`` followed by ``get_context`` lookups that walk down
        the hierarchy), so the recorded function results are exactly what the
        tools return. Requires ``self.icd10_tree`` to be loaded.

        Returns:
            A list of ``ChatMessage`` objects: the user question, alternating
            assistant tool calls and function results, and a final answer.

        Raises:
            RuntimeError: If a tool result does not mention the code that the
                next step of the example depends on.
        """
        history = [ChatMessage.user("Is there a condition that corresponds to rheumatoid arthritis in a patient's wrist, when they also have neuropathy?")]

        def replay(tool_name, expected_code, **arguments):
            # Record the assistant's tool call, run the real tool, then record its result.
            call = ToolCall.from_function(tool_name, **arguments)
            history.append(ChatMessage.assistant(content=None, tool_calls=[call]))
            result = getattr(self, tool_name)(**arguments)
            # Explicit check instead of `assert`, so it is not stripped under -O.
            if expected_code not in result:
                raise RuntimeError(
                    f"Few-shot example is inconsistent: expected {expected_code!r} "
                    f"in the result of {tool_name}({arguments!r})."
                )
            history.append(ChatMessage.function(tool_name, result, call.id))

        replay("semantic_search", "M05.5", query="rheumatoid arthritis of wrist with polyneuropathy")
        replay("get_context", "M05.53", code="M05.5")
        replay("get_context", "M05.539", code="M05.53")
        # Also inspect M05.539 itself, so the example shows what the most specific code looks like.
        replay("get_context", "M05.539", code="M05.539")

        history.append(ChatMessage.assistant("Yes, there are multiple codes for rheumatoid arthritis in the wrist with polyneuropathy. Since you didn't mention left or right, M05.539 'Rheumatoid polyneurop w rheumatoid arthritis of unsp wrist' appears to be the most appropriate."))
        return history