File size: 6,800 Bytes
c0aa211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from kani_utils.base_kanis import StreamlitKani
from kani import AIParam, ai_function, ChatMessage, ToolCall
import streamlit as st

from typing import Annotated
import pandas as pd
import re
import yaml

from icdtree import ICDTree, get_tree


class ICDBot(StreamlitKani):
    """Chat agent that lets an LLM explore the ICD-10-CM code hierarchy.

    Three tools are exposed to the model through ``@ai_function``:
    ``semantic_search``, ``think`` and ``get_context``.

    NOTE: the docstrings of the ``@ai_function`` methods are sent to the model
    as tool descriptions, so they are written for the model and kept short;
    maintainer notes live in ``#`` comments instead.
    """

    def __init__(self, *args, **kwargs):
        """Set the system prompt and UI metadata, then load the ICD-10 tree.

        All arguments are forwarded to ``StreamlitKani``; a caller-supplied
        ``system_prompt`` is overwritten.
        """
        # `_build_fewshot()` is not passed as `chat_history` here: it runs the
        # tools, which need `self.icd10_tree`, and the tree is only assigned
        # after `super().__init__()` below.

        kwargs['system_prompt'] = """
        You are ICDBot, a chatbot with programmatic access to the ICD-10-CM vocabulary. You can search for ICD-10 codes via the `semantic_search` function and query for terms' parent and children nodes in the hierarchy.

        * Use the `get_context` function to find the most specific code relevant to the user's query by navigating up and down the term hierarchy. For example, if the search for "fracture humerus" returns S42.2, you might use `get_context` on S42.2 to see its parent code and its children codes S42.20, S42.21, ..., to identify the most specific code that matches the user's query.
        """.strip()

        super().__init__(*args, **kwargs)

        # Avatars for the agent and the user; can be URLs or emojis.
        self.avatar = "🤖"
        self.user_avatar = "👤"

        # The name and greeting are shown at the start of the chat.
        # The greeting is not known to the LLM; it only prompts the user.
        self.name = "ICDBot"
        self.greeting = "Hello, I'm ICDBot. Nice to meet you!"

        # The description is shown in the sidebar.
        self.description = "An agent with access to the ICD-10-CM vocabulary."

        self.icd10_tree = get_tree()

    @staticmethod
    def _format_node(node):
        """Render a tree node as ``"<code>: <short description>; Valid: <flag>"``."""
        return f"{node.code}: {node.short_desc}; Valid: {node.valid}"

    @ai_function()
    def semantic_search(self,
                        query: Annotated[str, AIParam(desc="Query string to search for.")],
                        limit: Annotated[int, AIParam(desc="Maximum number of results to return. Default to 10 unless necessary.")]=10,
                        offset: Annotated[int, AIParam(desc="Number of results to skip.")]=0):
        """Search for ICD-10 codes similar to the query using an embedding-search strategy."""
        # The tree yields (node, <second value>) pairs; only the node is reported.
        hits = self.icd10_tree.semantic_search(query, limit=limit, offset=offset)
        matches = [self._format_node(node) for node, _ in hits]
        return yaml.dump({"Matches": matches})

    @ai_function()
    def think(self, thought: Annotated[str, AIParam(desc="The thought to think.")]):
        """Use this function to record internal thoughts or notes, and for planning complex actions. You may call this function more than once in succession for complex plans."""
        # Deliberately a no-op: the value of the call is the `thought` argument
        # itself, which stays in the chat history.
        return None

    @ai_function()
    def get_context(self,
                    code: Annotated[str, AIParam(desc="ICD-10 code to retrieve context for")]):
        """Given an ICD-10 code, return information about its parent, children, and siblings."""
        node = self.icd10_tree.get_node_by_code(code)
        if not node:
            return f"Code {code} not found."

        parent = node.parent
        if parent:
            parent_string = self._format_node(parent)
            # A node is not its own sibling; it is already reported under "Node".
            siblings = [other for other in parent.children if other.code != node.code]
        else:
            parent_string = "No parent"
            siblings = []

        context = {
            "Node": self._format_node(node),
            "Parent": parent_string,
            "Siblings": [self._format_node(sibling) for sibling in siblings],
            "Children": [self._format_node(child) for child in node.children],
        }
        # sort_keys=False keeps the Node/Parent/Siblings/Children order above
        # instead of PyYAML's default alphabetical ordering.
        return yaml.dump(context, sort_keys=False)

    def _build_fewshot(self):
        """Build a worked example conversation that demonstrates the tools.

        The example searches for a condition and then walks down the hierarchy
        with ``get_context`` until it reaches the most specific code. The tools
        are really executed, so ``self.icd10_tree`` must already be set.

        Returns:
            A list of ``ChatMessage`` objects: the user question, alternating
            assistant tool calls and tool results, and the final answer.

        Raises:
            RuntimeError: if a hierarchy lookup does not mention the code the
                scripted example goes on to use.
        """
        history = [ChatMessage.user("Is there a condition that corresponds to rheumatoid arthritis in a patient's wrist, when they also have neuropathy?")]

        def call_tool(tool_name, expected_code, **tool_kwargs):
            # Record the assistant's tool call, run the real tool, then record
            # its output under the same call id so the two messages are linked.
            tool_call = ToolCall.from_function(tool_name, **tool_kwargs)
            history.append(ChatMessage.assistant(content=None, tool_calls=[tool_call]))
            output = getattr(self, tool_name)(**tool_kwargs)
            # Explicit check rather than `assert`, which is stripped under -O.
            if expected_code is not None and expected_code not in output:
                raise RuntimeError(
                    f"Few-shot example is out of date: {tool_name}({tool_kwargs}) "
                    f"did not return {expected_code}."
                )
            history.append(ChatMessage.function(tool_name, output, tool_call.id))

        # Embedding search results depend on the index, so they are not validated.
        call_tool("semantic_search", None, query="rheumatoid arthritis of wrist with polyneuropathy")
        # Walk down the hierarchy: M05.5 -> M05.53 -> M05.539.
        call_tool("get_context", "M05.53", code="M05.5")
        call_tool("get_context", "M05.539", code="M05.53")

        history.append(ChatMessage.assistant("Yes, there are multiple codes for rheumatoid arthritis in the wrist with polyneuropathy. Since you didn't mention left or right, M05.539 'Rheumatoid polyneurop w rheumatoid arthritis of unsp wrist' appears to be the most appropriate."))
        return history