from kani_utils.base_kanis import StreamlitKani
from kani import AIParam, ai_function, ChatMessage, ToolCall
import streamlit as st
from typing import Annotated
import pandas as pd
import re
import yaml
from icdtree import ICDTree, get_tree


class ICDBot(StreamlitKani):
    """Streamlit chat agent exposing ICD-10-CM search and hierarchy-navigation tools to the LLM."""

    # Be sure to override the __init__ method to pass any parameters to the superclass
    def __init__(self, *args, **kwargs):
        # NOTE: few-shot priming via _build_fewshot() is disabled. It executes the tool
        # methods, which read self.icd10_tree, and that attribute is only assigned after
        # super().__init__() below, so the line cannot be enabled here as-is.
        # kwargs['chat_history'] = self._build_fewshot()
        kwargs['system_prompt'] = """
You are ICDBot, a chatbot with programmatic access to the ICD-10-CM vocabulary. You can search for ICD-10 codes via the `semantic_search` function and query for terms' parent and children nodes in the hierarchy.

* Use the `get_context` function to find the most specific code relevant to the user's query by navigating up and down the term hierarchy. For example, if the search for "fracture humerus" returns S42.2, you might use `get_context` to find the parent of S42.2 and its children codes S42.20, S42.21, ..., to identify the most specific code that matches the user's query.
""".strip()

        super().__init__(*args, **kwargs)

        # Define avatars for the agent and user
        # Can be URLs or emojis
        self.avatar = "🤖"
        self.user_avatar = "👤"

        # The name and greeting are shown at the start of the chat
        # The greeting is not known to the LLM, it serves as a prompt for the user
        self.name = "ICDBot"
        self.greeting = "Hello, I'm ICDBot. Nice to meet you!"

        # The description is shown in the sidebar and provides more information about the agent
        self.description = "An agent with access to the ICD-10-CM vocabulary."

        # Hierarchy used by every tool below; must exist before any tool is invoked.
        self.icd10_tree = get_tree()

    @staticmethod
    def _format_node(node):
        """Render a tree node as the one-line "code: short_desc; Valid: valid" summary shown to the LLM."""
        return f"{node.code}: {node.short_desc}; Valid: {node.valid}"

    # NOTE: the docstrings of @ai_function methods are the tool descriptions sent to the
    # LLM, so they are kept verbatim; maintainer notes live in '#' comments instead.
    @ai_function()
    def semantic_search(self,
                        query: Annotated[str, AIParam(desc="Query string to search for.")],
                        limit: Annotated[int, AIParam(desc="Maximum number of results to return. Default to 10 unless necessary.")]=10,
                        offset: Annotated[int, AIParam(desc="Number of results to skip.")]=0):
        """Search for ICD-10 codes similar to the query using an embedding-search strategy."""
        # The tree returns (node, <second value>) pairs; only the node is reported here.
        search_results = self.icd10_tree.semantic_search(query, limit=limit, offset=offset)
        matches = [self._format_node(node) for node, _ in search_results]
        return yaml.dump({"Matches": matches})

    @ai_function()
    def think(self, thought: Annotated[str, AIParam(desc="The thought to think.")]):
        """Use this function to record internal thoughts or notes, and for planning complex actions. You may call this function more than once in succession for complex plans."""
        # Scratchpad tool: the thought lives only in the tool-call arguments; nothing is computed.
        return None

    @ai_function()
    def get_context(self, code: Annotated[str, AIParam(desc="ICD-10 code to retrieve context for")]):
        """Given an ICD-10 code, return information about its parent, children, and siblings."""
        node = self.icd10_tree.get_node_by_code(code)
        if not node:
            return f"Code {code} not found."

        parent = node.parent
        # Siblings are taken as parent.children, which presumably includes the node itself.
        siblings = parent.children if parent else []

        # format the result as yaml
        return yaml.dump({
            "Node": self._format_node(node),
            "Parent": self._format_node(parent) if parent else "No parent",
            "Siblings": [self._format_node(sibling) for sibling in siblings],
            "Children": [self._format_node(child) for child in node.children],
        })

    def _fewshot_tool_step(self, history, function_name, arguments, expect=None):
        """Append one assistant tool call and its real result to *history*; return the result.

        Args:
            history: list of ChatMessage, mutated in place.
            function_name: name of a tool method on this class (e.g. "get_context").
            arguments: keyword arguments for both the recorded ToolCall and the actual call.
            expect: optional substring the result must contain; guards against the
                worked example drifting out of sync with the underlying tree.

        Raises:
            RuntimeError: if *expect* is given and missing from the tool result.
        """
        call = ToolCall.from_function(function_name, **arguments)
        history.append(ChatMessage.assistant(content=None, tool_calls=[call]))
        result = getattr(self, function_name)(**arguments)
        if expect is not None and expect not in result:
            raise RuntimeError(
                f"Few-shot example is stale: {function_name}({arguments}) "
                f"did not return {expect!r}."
            )
        history.append(ChatMessage.function(function_name, result, call.id))
        return result

    def _build_fewshot(self):
        """Build a worked example conversation that drills down to M05.539 using the real tools.

        Requires self.icd10_tree to be set, because every step executes the actual tool.
        """
        history = [ChatMessage.user("Is there a condition that corresponds to rheumatoid arthritis in a patient's wrist, when they also have neuropathy?")]
        # Embedding search ranking is not fixed here; the M05.5 check flags a stale example.
        self._fewshot_tool_step(history, "semantic_search",
                                {"query": "rheumatoid arthritis with polyneuropathy"}, expect="M05.5")
        self._fewshot_tool_step(history, "get_context", {"code": "M05.5"}, expect="M05.53")
        self._fewshot_tool_step(history, "get_context", {"code": "M05.53"}, expect="M05.539")
        # M05.539 is a leaf; yaml.dump renders its empty children list as "Children: []".
        self._fewshot_tool_step(history, "get_context", {"code": "M05.539"}, expect="Children: []")
        history.append(ChatMessage.assistant("Yes, there are multiple codes for rheumatoid arthritis in the wrist with polyneuropathy. Since you didn't mention left or right, M05.539 'Rheumatoid polyneurop w rheumatoid arthritis of unsp wrist' appears to be the most appropriate."))
        return history