oneilsh commited on
Commit
c0aa211
·
1 Parent(s): cd31ef7

trying huggingface spaces...

Browse files
Makefile ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# None of these targets produce a file with the target's name, so all are phony.
.PHONY: install run build-tree test-search test-semantic-search fetch-data

# Base URL of the CDC's 2025 ICD-10-CM release files.
ICD_URL := https://ftp.cdc.gov/pub/Health_Statistics/NCHS/Publications/ICD10CM/2025

# $(call fetch-zip,URL,DIR,ZIPNAME)
# Download ZIPNAME into DIR and unpack it there, unless the archive is already
# present. A failed download is removed again so that the next run retries
# instead of finding an empty archive and skipping the download.
define fetch-zip
	@if [ ! -f $(2)/$(3) ]; then \
		wget "$(1)" -O $(2)/$(3) || { rm -f $(2)/$(3); exit 1; }; \
		unzip $(2)/$(3) -d $(2); \
	fi
endef

install:
	poetry install --no-root

run:
	poetry run streamlit run app.py

build-tree:
	poetry run python -c 'from main import build_tree_cache; build_tree_cache()'

test-search:
	poetry run python -c 'from main import test_search; test_search()'

test-semantic-search:
	poetry run python -c 'from main import test_semantic_search; test_semantic_search()'

fetch-data:
	mkdir -p data
	mkdir -p data/addendum
	mkdir -p data/code_descriptions
	mkdir -p data/poa_exempt_codes
	mkdir -p data/table_index

	$(call fetch-zip,$(ICD_URL)/ICD10-CM%20Code%20Descriptions%202025.zip,data/code_descriptions,code_descriptions.zip)

	$(call fetch-zip,$(ICD_URL)/ICD10cm-addendum-2025.zip,data/addendum,addendum.zip)

	$(call fetch-zip,$(ICD_URL)/icd10cm-table-index-2025.zip,data/table_index,table_index.zip)

	$(call fetch-zip,$(ICD_URL)/POAexemptCodesFY25.zip,data/poa_exempt_codes,poa_exempt_codes.zip)

	@if [ ! -f data/guidelines.pdf ]; then \
		wget "$(ICD_URL)/icd-10-cm-FY25-guidelines-october%20-2024.pdf" -O data/guidelines.pdf || { rm -f data/guidelines.pdf; exit 1; }; \
	fi
README.md CHANGED
@@ -1,12 +1,4 @@
1
- ---
2
- title: Icdbot
3
- emoji: 📊
4
- colorFrom: indigo
5
- colorTo: pink
6
- sdk: streamlit
7
- sdk_version: 1.44.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # ICDBot
2
+
3
+ ICDBot is a web-based agent with access to the ICD-10-CM vocabulary through search, hierarchy navigation, and other methods. It is currently considered experimental.
 
 
 
 
 
 
 
4
 
 
__pycache__/agents.cpython-312.pyc ADDED
Binary file (7.27 kB). View file
 
__pycache__/agents.cpython-313.pyc ADDED
Binary file (7.54 kB). View file
 
__pycache__/globals.cpython-312.pyc ADDED
Binary file (212 Bytes). View file
 
__pycache__/globals.cpython-313.pyc ADDED
Binary file (212 Bytes). View file
 
__pycache__/icdnode.cpython-312.pyc ADDED
Binary file (6.48 kB). View file
 
__pycache__/icdnode.cpython-313.pyc ADDED
Binary file (6.61 kB). View file
 
__pycache__/icdtree.cpython-312.pyc ADDED
Binary file (7.62 kB). View file
 
__pycache__/icdtree.cpython-313.pyc ADDED
Binary file (10.8 kB). View file
 
__pycache__/icdutils.cpython-312.pyc ADDED
Binary file (2 kB). View file
 
__pycache__/main.cpython-312.pyc ADDED
Binary file (2.16 kB). View file
 
__pycache__/main.cpython-313.pyc ADDED
Binary file (2.09 kB). View file
 
agents.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from kani_utils.base_kanis import StreamlitKani
2
+ from kani import AIParam, ai_function, ChatMessage, ToolCall
3
+ import streamlit as st
4
+
5
+ from typing import Annotated
6
+ import pandas as pd
7
+ import re
8
+ import yaml
9
+
10
+ from icdtree import ICDTree, get_tree
11
+
12
+
13
class ICDBot(StreamlitKani):
    """Streamlit chat agent exposing the ICD-10-CM tree to the LLM as tools.

    The tools are ``semantic_search`` (embedding search over code
    descriptions), ``get_context`` (parent / siblings / children of a code)
    and ``think`` (a scratchpad that returns nothing).
    """

    # Be sure to override the __init__ method to pass any parameters to the superclass
    def __init__(self, *args, **kwargs):
        """Set the system prompt, UI metadata, and load the ICD tree."""
        #kwargs['chat_history'] = self._build_fewshot()

        # Built from adjacent literals so source indentation never leaks into the prompt.
        kwargs['system_prompt'] = (
            "You are ICDBot, a chatbot with programmatic access to the ICD-10-CM vocabulary. "
            "You can search for ICD-10 codes via semantic search and query for terms' parent "
            "and children nodes in the hierarchy.\n"
            "\n"
            "* Use the `get_context` function to find the most specific code relevant to the "
            "user's query by navigating up and down the term hierarchy. For example, if the "
            "search for \"fracture humerus\" returns S42.2, you might use `get_context` to find "
            "the parent code S42 and its children codes S42.20, S42.21, ..., to identify the "
            "most specific code that matches the user's query."
        )

        super().__init__(*args, **kwargs)

        # Define avatars for the agent and user
        # Can be URLs or emojis
        self.avatar = "🤖"
        self.user_avatar = "👤"

        # The name and greeting are shown at the start of the chat
        # The greeting is not known to the LLM, it serves as a prompt for the user
        self.name = "ICDBot"
        self.greeting = "Hello, I'm ICDBot. Nice to meet you!"

        # The description is shown in the sidebar and provides more information about the agent
        self.description = "An agent with access to the ICD-10-CM vocabulary."

        self.icd10_tree = get_tree()

    @staticmethod
    def _describe(node):
        """Render one node as the ``code: description; Valid: flag`` line shown to the LLM."""
        return f"{node.code}: {node.short_desc}; Valid: {node.valid}"

    @ai_function()
    def semantic_search(self,
                        query: Annotated[str, AIParam(desc="Query string to search for.")],
                        limit: Annotated[int, AIParam(desc="Maximum number of results to return. Default to 10 unless necessary.")]=10,
                        offset: Annotated[int, AIParam(desc="Number of results to skip.")]=0):

        """Search for ICD-10 codes similar to the query using an embedding-search strategy."""

        search_results = self.icd10_tree.semantic_search(query, limit=limit, offset=offset)

        # Each hit is a (node, distance) pair; only the node is reported.
        results_strings = [self._describe(node) for node, _ in search_results]
        return yaml.dump({"Matches": results_strings})

    @ai_function()
    def think(self, thought: Annotated[str, AIParam(desc="The thought to think.")]):
        """Use this function to record internal thoughts or notes, and for planning complex actions. You may call this function more than once in succession for complex plans."""
        return None

    @ai_function()
    def get_context(self,
                    code: Annotated[str, AIParam(desc="ICD-10 code to retrieve context for")]):
        """Given an ICD-10 code, return information about its parent, children, and siblings."""
        node = self.icd10_tree.get_node_by_code(code)
        if not node:
            return f"Code {code} not found."

        parent = node.parent
        # A node is not its own sibling: drop it from the parent's child list.
        siblings = [sib for sib in parent.children if sib is not node] if parent else []

        # format the result as yaml
        return yaml.dump({
            "Node": self._describe(node),
            "Parent": self._describe(parent) if parent else "No parent",
            "Siblings": [self._describe(sib) for sib in siblings],
            "Children": [self._describe(child) for child in node.children],
        })

    def _build_fewshot(self):
        """Build a few-shot chat history that walks from a search hit down to a leaf code.

        NOTE(review): this replays tools named ``search_icd`` and
        ``get_children``, which this class does not define, so calling it
        raises AttributeError. It is only referenced from a commented-out line
        in ``__init__``; port it to ``semantic_search`` / ``get_context``
        before re-enabling it.
        """
        history = [ChatMessage.user("Is there a condition that corresponds to rheumatoid arthritis in a patient's wrist, when they also have neuropathy?")]

        # (tool name, arguments, sanity check on the tool's textual result)
        steps = [
            ("search_icd", {"pattern": "[Rr]heumatoid arthritis", "max_depth": 1}, lambda res: "M05.5" in res),
            ("get_children", {"code": "M05.5"}, lambda res: "M05.53" in res),
            ("get_children", {"code": "M05.53"}, lambda res: "M05.539" in res),
            # M05.539 is a leaf; the empty lookup is still recorded in the history
            ("get_children", {"code": "M05.539"}, lambda res: res == "No children."),
        ]

        for tool_name, tool_kwargs, looks_right in steps:
            call = ToolCall.from_function(tool_name, **tool_kwargs)
            history.append(ChatMessage.assistant(content=None, tool_calls=[call]))

            res = getattr(self, tool_name)(**tool_kwargs)
            assert looks_right(res)

            # the function message must carry the id of the call it answers
            history.append(ChatMessage.function(tool_name, res, call.id))

        history.append(ChatMessage.assistant("Yes, there are multiple codes for rheumatoid arthritis in the wrist with polyneuropathy. Since you didn't mention left or right, M05.539 'Rheumatoid polyneurop w rheumatoid arthritis of unsp wrist' appears to be the most appropriate."))
        return history
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Streamlit entry point: serves the ICDBot agent through kani-utils.

########################
##### 0 - load libs
########################

# Streamlit server helpers from kani-utils
import kani_utils.kani_streamlit_server as ks

# environment access and .env loading for API keys
import os
import dotenv # pip install python-dotenv

# LLM engine
from kani.engines.openai import OpenAIEngine

# agents defined by this application
from agents import ICDBot

# pull API keys from a local .env file (e.g. OPENAI_API_KEY=....; keep .env out of git)
dotenv.load_dotenv()

########################
##### 1 - Configuration
########################

_issues_url = "https://github.com/.../issues"
_menu_items = {
    "Get Help": _issues_url,
    "Report a Bug": _issues_url,
    "About": "An AI agent with access to the ICD-10-CM vocabulary.",
}

# Page settings are forwarded to streamlit.set_page_config
# (https://docs.streamlit.io/library/api-reference/utilities/st.set_page_config).
# This call MUST come before anything else touches the app.
ks.initialize_app_config(
    show_function_calls=True,
    page_title="ICDBot",
    page_icon="🤖",  # can also be a URL
    initial_sidebar_state="expanded",  # or "collapsed"
    menu_items=_menu_items,
)


########################
##### 2 - Define Agents
########################

# engine shared by the agents (see the Kani documentation for alternatives)
engine = OpenAIEngine(os.environ["OPENAI_API_KEY"], model="gpt-4o")
# engine = OpenAIEngine(os.environ["OPENAI_API_KEY"], model="o1")


def get_agents():
    """Return the agents to serve, keyed by the name shown in the UI."""
    icd_bot = ICDBot(engine, prompt_tokens_cost=0.005, completion_tokens_cost=0.015)
    return {"ICDBot": icd_bot}


# the server calls get_agents whenever it needs fresh agent instances
ks.set_app_agents(get_agents)


########################
##### 3 - Serve App
########################

ks.serve_app()
globals.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Module-level progress counters shared between icdtree.py and icdnode.py.
# total_nodes: assigned by ICDTree.__init__ to the number of nodes in the tree.
total_nodes = None
# num_embedded: incremented by ICDNode._add_embeddings each time a node is embedded.
num_embedded = 0
icdnode.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from sentence_transformers import SentenceTransformer
3
+ import globals
4
+ import sys
5
+
6
class ICDNode:
    """One entry of the ICD-10-CM order file, linked into a code hierarchy.

    ``parent`` and ``children`` start empty and are wired up through
    ``_add_child``; ``embedding`` stays ``None`` until ``_add_embeddings``
    runs; ``level`` is 0 until ``_set_levels`` is called on the root.
    """

    def __init__(self, order, code, valid, short_desc, long_desc):
        """Store the fields of one record; ``code`` is normalised to dotted form."""
        self.order = order

        # Canonical form has a dot after the third character (S42001A -> S42.001A).
        code = code.replace('.', '')
        if len(code) > 3:
            self.code = code[:3] + '.' + code[3:]
        else:
            self.code = code

        self.valid = valid
        self.short_desc = short_desc
        self.long_desc = long_desc
        self.children = []
        self.parent = None
        self.embedding = None
        self.level = 0

    def _set_levels(self):
        """Set this node's level to its parent's level plus one (roots keep 0), then recurse into children."""
        if self.parent:
            self.level = self.parent.level + 1
        for child in self.children:
            child._set_levels()

    def get_parent(self):
        """Return the parent node of this node."""
        return self.parent

    def get_children(self):
        """Return the immediate children nodes of this node."""
        return self.children

    def get_descendants(self):
        """Return the children nodes of this node and all subnodes as a list, in a depth-first order."""
        descendants = []
        for child in self.children:
            descendants.append(child)
            descendants.extend(child.get_descendants())
        return descendants

    def get_ancestors(self):
        """Return the parent nodes of this node and all super nodes as a list, nearest first."""
        ancestors = []
        parent = self.parent
        while parent:
            ancestors.append(parent)
            parent = parent.parent
        return ancestors

    def _add_embeddings(self, embedding_model, prefix = ""):
        """Add embeddings to this node and all its descendants, logging progress to stderr."""
        self.embedding = embedding_model.encode([prefix + self.short_desc])[0]
        globals.num_embedded += 1
        # total_nodes is None until a tree sets it; report 0% rather than crash.
        total = globals.total_nodes
        percent_complete = round(100 * (globals.num_embedded / total), 2) if total else 0.0
        sys.stderr.write(f"({percent_complete}%) " + str(self) + "\n")
        for child in self.children:
            child._add_embeddings(embedding_model, prefix)

    def _sort_children(self):
        """Sort the children nodes by their code."""
        self.children.sort(key=lambda x: x.code)

    def _add_child(self, child_node):
        """Add a child node to this node and point the child back at it."""
        self.children.append(child_node)
        child_node.parent = self

    @staticmethod
    def _normalize_valid(flag):
        """Map True/False, 1/0 and "1"/"0" onto the "1"/"0" strings used for comparison."""
        if isinstance(flag, bool):
            return "1" if flag else "0"
        return str(flag).strip()

    def _matches_validity(self, valid):
        """Return True when no validity filter is given or this node's flag equals it."""
        if valid is None:
            return True
        return self._normalize_valid(self.valid) == self._normalize_valid(valid)

    def regex_search(self, pattern, max_depth=None, current_depth=0, valid = None):
        """Return the nodes in this subtree whose short_desc, long_desc, or code
        matches the regex pattern (case-insensitive), in depth-first order.

        If valid is truthy (True, 1 or "1"), only valid nodes are returned; if
        falsy (False, 0 or "0"), only invalid ones; if None, all matches.
        max_depth is the maximum depth below the starting node to search;
        None means no limit.
        """
        if max_depth is None:
            max_depth = float('inf')

        results = []
        texts = (self.short_desc, self.long_desc, self.code)
        if any(re.search(pattern, text, re.IGNORECASE) for text in texts):
            if self._matches_validity(valid):
                results.append(self)

        if current_depth < max_depth:
            # Always descend: the validity filter applies to the nodes that are
            # returned, not to the path. Valid codes usually sit beneath invalid
            # header codes, so pruning on a child's flag would hide them.
            for child in self.children:
                results.extend(child.regex_search(pattern, max_depth = max_depth, current_depth = current_depth + 1, valid = valid))

        return results

    def __repr__(self):
        """Tab-separated one-line summary; the embedding is abbreviated to four components."""
        embedding = getattr(self, 'embedding', None)
        if embedding is not None:
            embed_str_rep = "[" + ", ".join(str(round(x, 5)) for x in embedding[:4]) + ", ...]"
        else:
            # not embedded yet
            embed_str_rep = None
        return f"Code: {self.code}\tLevel: {self.level}\tValid: {self.valid}\tDesc: {self.short_desc}\tVec: {embed_str_rep}"
icdtree.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from icdnode import ICDNode
3
+ import globals
4
+ import sys
5
+ import faiss
6
+ import numpy as np
7
+ import os
8
+ import zipfile
9
+ import requests
10
+ import shutil
11
+
12
def _download_and_extract_icd_data(url, zip_file_path, extract_dir):
    """Fetch the CDC code-description archive (unless already on disk) and unzip it.

    Exits the process with status 1 if the download fails or the archive is
    not a valid ZIP file.
    """
    # Ensure the directory exists
    os.makedirs(extract_dir, exist_ok=True)

    # Download the file if it doesn't exist
    if not os.path.exists(zip_file_path):
        partial_path = zip_file_path + ".part"
        try:
            sys.stderr.write("Downloading ICD data...\n")
            # the timeout keeps a stalled connection from hanging the app forever
            response = requests.get(url, timeout=60)
            if response.status_code != 200:
                sys.stderr.write(f"Failed to download file: {response.status_code}\n")
                sys.exit(1)
            # Write to a temporary name first so an interrupted write never
            # leaves a truncated archive that later runs would trust.
            with open(partial_path, 'wb') as f:
                f.write(response.content)
            os.replace(partial_path, zip_file_path)
        except Exception as e:
            sys.stderr.write(f"Error downloading file: {e}\n")
            sys.exit(1)

    # Unzip the file
    try:
        sys.stderr.write("Unzipping ICD data...\n")
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
    except zipfile.BadZipFile:
        sys.stderr.write("Error: The downloaded file is not a valid ZIP file.\n")
        # remove the corrupt archive so the next run downloads it again
        os.remove(zip_file_path)
        sys.exit(1)


def get_tree(file_path = "data/code_descriptions/icd10cm-order-2025.txt", st_embedding_model = 'intfloat/e5-small-v2'):
    """Return an ICDTree for ``file_path``, using a pickle cache under ``data/``.

    The cache file name is an MD5 hash of ``file_path`` and
    ``st_embedding_model``. On a cache miss the tree is built from
    ``file_path`` and pickled. If ``file_path`` is absent, the CDC archive is
    downloaded and extracted first. The FAISS index is built after loading or
    building, so it is never part of the pickle.

    Returns None if preparing the data raised an ordinary exception. Download
    failures and corrupt archives exit the process with status 1.
    """
    import hashlib
    import pickle

    default_file_path = "data/code_descriptions/icd10cm-order-2025.txt"

    # cache file name: hash of the inputs that determine the tree's contents
    m = hashlib.md5()
    m.update(file_path.encode('utf-8'))
    m.update(st_embedding_model.encode('utf-8'))
    pickle_file_name = "data/" + m.hexdigest() + ".pkl"

    icd_tree = None
    try:
        with open(pickle_file_name, 'rb') as f:
            # NOTE: unpickling runs arbitrary code; only load cache files this
            # application wrote itself.
            icd_tree = pickle.load(f)
        sys.stderr.write("Loaded tree from pickle file\n")

    except FileNotFoundError:
        # Only hit the network when the order file is actually missing.
        if not os.path.exists(file_path):
            try:
                _download_and_extract_icd_data(
                    "https://ftp.cdc.gov/pub/Health_Statistics/NCHS/Publications/ICD10CM/2025/ICD10-CM%20Code%20Descriptions%202025.zip",
                    "data/code_descriptions/code_descriptions.zip",
                    "data/code_descriptions",
                )
            except Exception as e:
                sys.stderr.write("Error downloading or unzipping the file: " + str(e) + "\n")
                return None

            # The archive only provides the default order file, so a custom
            # path that is still missing falls back to it.
            if not os.path.exists(file_path) and os.path.exists(default_file_path):
                sys.stderr.write(f"{file_path} not found; using {default_file_path}\n")
                file_path = default_file_path

        sys.stderr.write("Building tree\n")
        icd_tree = ICDTree(file_path, st_embedding_model)

        # and save it to the pickle file
        os.makedirs("data", exist_ok=True)
        with open(pickle_file_name, 'wb') as f:
            pickle.dump(icd_tree, f)

    icd_tree.build_faiss_index()
    return icd_tree
86
+
87
class ICDTree:
    """ICD-10-CM code hierarchy with regex and embedding-based search.

    ``nodes_dict`` maps dot-less codes to nodes in file order; ``roots`` holds
    the nodes with no shorter code prefix in the file. The FAISS index is
    created by ``build_faiss_index``, not by the constructor.
    """

    def __init__(self, path, st_embedding_model):
        """Parse the order file at ``path`` and embed every node's short description."""
        self.nodes_dict = {}  # Dictionary to hold all nodes by code without dots
        self.roots = []  # List to hold root nodes
        self.build_tree(path)
        globals.total_nodes = len(self.nodes_dict)

        # e5-small-v2 seems to work well on synonym mapping (and not gated unlike the jina model): https://arxiv.org/html/2401.01943v2
        # intfloat's e5 models all require a "query: " prefix for semantic similarity tasks: see e.g. https://huggingface.co/intfloat/e5-small-v2
        self.embedding_model = SentenceTransformer(st_embedding_model)
        # Kept on the instance so queries are embedded the same way as the nodes.
        self.embedding_prefix = ""
        if "intfloat" in st_embedding_model and "e5" in st_embedding_model:
            self.embedding_prefix = "query: "

        for root in self.roots:
            root._add_embeddings(self.embedding_model, self.embedding_prefix)

    def build_faiss_index(self):
        """Build the FAISS L2 index over all node embeddings, in ``nodes_dict`` order."""
        # IndexFlatL2 seems to work slightly better than IndexFlatIP, at least for the one query I tried ('hurt tummy')
        self.index = faiss.IndexFlatL2(self.embedding_model.get_sentence_embedding_dimension())
        #self.index = faiss.IndexFlatIP(self.embedding_model.get_sentence_embedding_dimension())
        embeddings = [node.embedding for node in self.nodes_dict.values()]

        if embeddings:
            # FAISS expects a float32 matrix
            self.index.add(np.asarray(embeddings, dtype=np.float32))

    def semantic_search(self, query, limit = 10, offset = 0):
        """Return up to ``limit`` ``(node, distance)`` pairs nearest to ``query``,
        closest first, after skipping the first ``offset`` hits."""
        if not hasattr(self, "index"):
            # e.g. a tree that was just unpickled; the index is not part of the pickle
            self.build_faiss_index()

        offset = max(offset, 0)
        # never ask for more neighbours than exist: FAISS pads the surplus with -1
        wanted = min(limit + offset, self.index.ntotal)
        if wanted <= 0:
            return []

        # trees pickled before the prefix was stored have no such attribute
        prefix = getattr(self, "embedding_prefix", "")
        query_embedding = self.embedding_model.encode([prefix + query])[0]

        # search the faiss index
        D, I = self.index.search(np.asarray([query_embedding], dtype=np.float32), wanted)

        # index positions follow nodes_dict insertion order; build the list once
        nodes = list(self.nodes_dict.values())
        results = [(nodes[idx], float(dist)) for dist, idx in zip(D[0], I[0]) if idx >= 0]

        # sort results by distance
        results.sort(key=lambda x: x[1])

        return results[offset:]

    def regex_search(self, pattern, max_depth=None, valid = None):
        """Returns a list of node objects whose short_desc, long_desc, or
        code matches the regex pattern. max_depth is the maximum depth to
        search in the hierarchy; if None, search all levels.

        valid is passed through to ICDNode.regex_search as its validity filter;
        None returns all matching nodes.
        """
        results = []
        for root in self.roots:
            results.extend(root.regex_search(pattern, max_depth = max_depth, valid = valid))
        return results

    def build_tree(self, file_path):
        """Read the fixed-width order file and link the nodes into a hierarchy."""
        # First pass: Read the file and create all nodes
        with open(file_path, 'r') as f:
            for line in f:
                if not line.strip():
                    # a blank line would otherwise become an empty-code root
                    continue
                order = line[0:5].strip()
                code = line[6:13].strip()
                valid = line[14:15].strip()
                short_desc = line[16:76].strip()
                long_desc = line[77:].strip()

                node = ICDNode(order, code, valid, short_desc, long_desc)
                code_key = node.code.replace('.', '')  # Key without dots for consistent lookup
                self.nodes_dict[code_key] = node

        # Second pass: Set parent-child relationships
        for code_key, node in self.nodes_dict.items():
            parent_code_key = self.find_parent_code(code_key)
            if parent_code_key:
                self.nodes_dict[parent_code_key]._add_child(node)
            else:
                # No parent found, this node is a root
                self.roots.append(node)

        # Sort children of all nodes
        for node in self.nodes_dict.values():
            node._sort_children()

        # Set levels for all nodes
        for root in self.roots:
            root._set_levels()

    def find_parent_code(self, code_key):
        """Return the longest proper prefix of ``code_key`` that is a known code, or None."""
        for i in range(len(code_key) - 1, 0, -1):
            parent_code = code_key[:i]
            if parent_code in self.nodes_dict:
                return parent_code
        return None

    def print_tree(self):
        """Print all root nodes and their subtrees, indented two spaces per level."""
        for root in self.roots:
            # explicit stack; children are pushed reversed so they pop in order
            pending = [root]
            while pending:
                node = pending.pop()
                print("  " * node.level + f"{node.code} {node.short_desc}")
                pending.extend(reversed(node.children))

    def get_node_by_code(self, code):
        """Retrieve a node by its code (with or without dots); None if unknown."""
        code_key = code.replace('.', '')
        return self.nodes_dict.get(code_key)
main.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from icdtree import ICDTree, get_tree
2
+
3
+ # run as e.g.
4
+ # poetry run python -c 'import main; main.test_semantic_search()'
5
+
6
def build_tree_cache():
    """Build (or load) the ICD tree so that its pickle cache exists on disk.

    Exits with a non-zero status when ``get_tree`` reports failure by
    returning None, so ``make build-tree`` does not claim success.
    """
    if get_tree() is None:
        raise SystemExit("Tree could not be built; see the errors above.")
    print("Tree built and saved.")
9
+
10
+
11
def test_semantic_search():
    """Print the first two pages (10 hits each) of a semantic search for "pneumonia"."""
    tree = get_tree()
    #tree.build_faiss_index()

    def show(hits):
        # one (node, distance) tuple per line
        for hit in hits:
            print(hit)

    show(tree.semantic_search("pneumonia", limit=10))

    print("getting more")
    show(tree.semantic_search("pneumonia", limit=10, offset=10))
24
+
25
+
26
def test_search():
    """Regex-search the tree for "Torus" and print the subtrees of the
    shallowest matches, depth-first and in sibling order."""
    icd10_tree = get_tree()

    search_nodes = icd10_tree.regex_search(r'Torus', max_depth=5)
    if not search_nodes:
        return

    # get the highest-level nodes (closest to 0)
    search_nodes.sort(key=lambda x: x.level)
    top_level = search_nodes[0].level
    top_nodes = [node for node in search_nodes if node.level == top_level]

    # Depth-first walk with an explicit stack. Nodes are pushed in reverse so
    # that pop() yields them in their original order.
    stack = list(reversed(top_nodes))
    while stack:
        node = stack.pop()
        print(node)
        stack.extend(reversed(node.get_children()))
40
+
41
+
42
if __name__ == '__main__':
    # Running this module directly only prints a usage hint; the functions
    # above are meant to be invoked through `python -c`.
    print("Try running this file like `python -c 'import main; main.build_tree_cache()'` to build the tree cache.")
44
+
45
+ # s42 = icd10_tree.get_node_by_code('S42.001B')
46
+ # # index 0 is the yaml string, index 1 is the level
47
+ # print(s42.get_ancestors_yaml()[0])
48
+
49
+ # s42 = icd10_tree.get_node_by_code('S42')
50
+ # print(s42)
51
+ # for child in s42.get_children():
52
+ # print(child)
53
+ # for grandchild in child.get_children():
54
+ # print(grandchild)
55
+ # print(s42.get_descendants_yaml())
56
+
57
+ # res = icd10_tree.regex_search('.*', max_depth=0)
58
+ # for node in res:
59
+ # print(node)
60
+
61
+ # print("#######")
62
+
63
+ # focal = icd10_tree.get_node_by_code('M05.5')
64
+ # children = focal.get_children()
65
+ # for node in children:
66
+ # print(node)
67
+
68
+ # print("#######")
69
+
70
+ # focal = icd10_tree.get_node_by_code('M05.53')
71
+ # children = focal.get_children()
72
+ # for node in children:
73
+ # print(node)
74
+
75
+ # print("#######")
76
+
77
+ # focal = icd10_tree.get_node_by_code('M05.539')
78
+ # children = focal.get_children()
79
+ # for node in children:
80
+ # print(node)
81
+
82
+
oldcode.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Not really useful, but don't want to delete it yet
2
def get_ancestors_yaml(self):
    """Render this node beneath its chain of ancestors as indented YAML-like text.

    Returns a ``(text, level)`` tuple, where ``level`` is the number of
    ancestors above this node (0 for a root). For example:
    - A00 Cholera:
     - A00.0 Cholera due to Vibrio cholerae 01, biovar cholerae
     - A00.00 Cholera due to Vibrio cholerae 01, biovar cholerae, cholera gravis
    ..."""
    own_line = f"- {self.code} {self.short_desc}:\n"

    # a root is the base case: no indentation, level 0
    if not self.parent:
        return own_line, 0

    lineage, parent_level = self.parent.get_ancestors_yaml()
    level = parent_level + 1
    return lineage + ' ' * level + own_line, level
18
+
19
def get_descendants_yaml(self):
    """Return a YAML-like string for this node and all of its descendants.

    Each generation is indented one space further than its parent. For example:
    - A00 Cholera:
     - A00.0 Cholera due to Vibrio cholerae 01, biovar cholerae
     - A00.1 Cholera due to Vibrio cholerae 01, biovar eltor
     - A00.9 Cholera, unspecified
    ..."""
    pieces = [f"- {self.code} {self.short_desc}:\n"]
    for child in self.children:
        # Indent line by line. Replacing '\n' with '\n ' would also indent the
        # text after the final newline and push every later sibling one
        # column further right.
        for line in child.get_descendants_yaml().splitlines(keepends=True):
            pieces.append(" " + line)
    return "".join(pieces)
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "icdbot"
3
+ version = "0.1.0"
4
+ description = "An AI agent for intelligent querying of the ICD10CM vocabulary."
5
+ authors = ["Shawn T O'Neil <[email protected]>"]
6
+ license = "MIT"
7
+ readme = "README.md"
8
+ package-mode = false
9
+
10
+ [tool.poetry.dependencies]
11
+ python = "^3.10"
12
+ kani-utils = {git = "https://github.com/oneilsh/kani-utils.git"}
13
+ pandas = "^2.2.3"
14
+ kani = "^1.2.2"
15
+ transformers = "^4.48.0"
16
+ sentence-transformers = "^3.3.1"
17
+ peft = "^0.14.0"
18
+ faiss-cpu = "^1.9.0.post1"
19
+ numpy = "^2.2.2"
20
+ pyyaml = "^6.0.2"
21
+ streamlit = "^1.41.1"
22
+ requests = "^2.32.3"
23
+
24
+
25
+ [build-system]
26
+ requires = ["poetry-core"]
27
+ build-backend = "poetry.core.masonry.api"