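"""ICD-10-CM code tree with sentence-transformer embeddings and a faiss index.

Builds the hierarchy from the CDC/CMS icd10cm order file, caches the built tree
as a pickle keyed on the data path and embedding model name (see get_tree), and
supports semantic and regex search over the codes.
"""
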
import hashlib
import os
import pickle
import sys
import zipfile

import faiss
import numpy as np
import requests
from sentence_transformers import SentenceTransformer

import globals
from icdnode import ICDNode

def get_tree(file_path = "data/code_descriptions/icd10cm-order-2025.txt", st_embedding_model = 'intfloat/e5-small-v2'):
    # The built tree is cached in a pickle file whose name is derived from file_path
    # and st_embedding_model. If the pickle exists, load and return it; otherwise
    # build the tree, save it to the pickle, and return it.

    # Determine the pickle file name from an md5 hash of the inputs
    m = hashlib.md5()
    m.update(file_path.encode('utf-8'))
    m.update(st_embedding_model.encode('utf-8'))
    pickle_file_name = "data/" + m.hexdigest() + ".pkl"

    # Try to load the cached tree
    icd_tree = None
    try:
        with open(pickle_file_name, 'rb') as f:
            icd_tree = pickle.load(f)
            sys.stderr.write("Loaded tree from pickle file\n")

    except FileNotFoundError:
        # No cached tree yet. Make sure the source data file exists, downloading
        # and extracting the CDC archive if it is missing.
        if not os.path.exists(file_path):
            try:
                url = "https://ftp.cdc.gov/pub/Health_Statistics/NCHS/Publications/ICD10CM/2025/ICD10-CM%20Code%20Descriptions%202025.zip"
                zip_file_path = "data/code_descriptions/code_descriptions.zip"

                # Ensure the directory exists
                os.makedirs("data/code_descriptions", exist_ok=True)

                # Download the archive if it doesn't exist
                if not os.path.exists(zip_file_path):
                    try:
                        sys.stderr.write("Downloading ICD data...\n")
                        response = requests.get(url)
                        if response.status_code != 200:
                            sys.stderr.write(f"Failed to download file: {response.status_code}\n")
                            sys.exit(1)
                        with open(zip_file_path, 'wb') as f:
                            f.write(response.content)
                    except Exception as e:
                        sys.stderr.write(f"Error downloading file: {e}\n")
                        sys.exit(1)

                # Unzip the archive
                try:
                    sys.stderr.write("Unzipping ICD data...\n")
                    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                        zip_ref.extractall("data/code_descriptions")
                except zipfile.BadZipFile:
                    sys.stderr.write("Error: The downloaded file is not a valid ZIP file.\n")
                    sys.exit(1)

                # The archive extracts the order file directly; fall back to it if the
                # requested file_path still isn't present.
                if not os.path.exists(file_path):
                    file_path = "data/code_descriptions/icd10cm-order-2025.txt"
                if not os.path.exists(file_path):
                    sys.stderr.write(f"Error: expected {file_path} after extraction but it was not found.\n")
                    return None
            except Exception as e:
                sys.stderr.write("Error downloading or unzipping the file: " + str(e) + "\n")
                return None

        # the data file exists but no cached pickle does, so build the tree
        sys.stderr.write("Building tree\n")
        icd_tree = ICDTree(file_path, st_embedding_model)

        # and save it to the pickle file for next time
        os.makedirs(os.path.dirname(pickle_file_name), exist_ok=True)
        with open(pickle_file_name, 'wb') as f:
            pickle.dump(icd_tree, f)
    
    icd_tree.build_faiss_index()
    return icd_tree

class ICDTree:
    def __init__(self, path, st_embedding_model):
        self.nodes_dict = {}  # Dictionary to hold all nodes by code without dots
        self.roots = []       # List to hold root nodes
        self.build_tree(path)
        globals.total_nodes = len(self.nodes_dict)

        # e5-small-v2 seems to work well on synonym mapping (and, unlike the Jina model, it is not gated): https://arxiv.org/html/2401.01943v2
        # intfloat's e5 models expect a "query: " prefix on every input (documents and queries) for semantic similarity tasks: see e.g. https://huggingface.co/intfloat/e5-small-v2

        self.embedding_model = SentenceTransformer(st_embedding_model)
        self.query_prefix = ""
        if "intfloat" in st_embedding_model and "e5" in st_embedding_model:
            self.query_prefix = "query: "

        for root in self.roots:
            root._add_embeddings(self.embedding_model, self.query_prefix)


    def build_faiss_index(self):
        # Build a faiss index over all node embeddings.
        # IndexFlatL2 seems to work slightly better than IndexFlatIP, at least for the one query I tried ('hurt tummy')
        self.index = faiss.IndexFlatL2(self.embedding_model.get_sentence_embedding_dimension())
        #self.index = faiss.IndexFlatIP(self.embedding_model.get_sentence_embedding_dimension())

        # Index positions correspond to the insertion order of nodes_dict; semantic_search
        # relies on this ordering when mapping result indices back to nodes.
        embeddings = [node.embedding for node in self.nodes_dict.values()]
        self.index.add(np.array(embeddings, dtype=np.float32))  # faiss expects float32


    def semantic_search(self, query, limit = 10, offset = 0):
        """Search for nodes similar to the query using the sentence transformer model."""
        # Use the same prefix that was applied to the node embeddings (e.g. "query: " for
        # e5 models); getattr guards against trees cached before the prefix was stored.
        prefix = getattr(self, "query_prefix", "")
        query_embedding = self.embedding_model.encode([prefix + query])[0]

        # search the faiss index
        D, I = self.index.search(np.array([query_embedding], dtype=np.float32), limit + offset)
        nodes = list(self.nodes_dict.values())  # same order as the faiss index
        results = []
        for i in range(len(I[0])):
            idx = int(I[0][i])
            if idx < 0:  # faiss pads with -1 when fewer than k vectors are available
                continue
            results.append((nodes[idx], float(D[0][i])))

        # sort by the node's level
        # results.sort(key=lambda x: x[0].level)
        # sort results by distance
        results.sort(key=lambda x: x[1])

        return results[offset:]


    # def get_descendants_yaml(self, code):
    #     """Returns a YAML string representing the node and its descendants, with indentation."""
    #     root = self.get_node_by_code(code)
    #     if root:
    #         return root.get_descendants_yaml()
    #     else:
    #         return None
        

    def regex_search(self, pattern, max_depth=None, valid = None):
        """Returns a list of node objects whose short_desc, long_desc, or 
        code matches the regex pattern. max_depth is the maximum depth to 
        search in the hierarchy; if None, search all levels.
        
        If valid is set to 1, only return valid nodes; if 0, only return invalid nodes. If None, return all matching nodes.
        """

        results = []
        for root in self.roots:
            results.extend(root.regex_search(pattern, max_depth = max_depth, valid = valid))
        return results

    def build_tree(self, file_path):
        # First pass: Read the file and create all nodes
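        # The order file is fixed width: an order number, the code, a flag
        # marking whether the code is valid for submission (1) or a non-billable
        # header (0), a short description, and a long description, at the
        # column offsets sliced below.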
        with open(file_path, 'r') as f:
            for line in f:
                order = line[0:5].strip()
                code = line[6:13].strip()
                valid = line[14:15].strip()
                short_desc = line[16:76].strip()
                long_desc = line[77:].strip()

                node = ICDNode(order, code, valid, short_desc, long_desc)
                code_key = node.code.replace('.', '')  # Key without dots for consistent lookup
                self.nodes_dict[code_key] = node

        # Second pass: Set parent-child relationships
        for code_key, node in self.nodes_dict.items():
            parent_code_key = self.find_parent_code(code_key)
            if parent_code_key:
                parent_node = self.nodes_dict[parent_code_key]
                parent_node._add_child(node)
            else:
                # No parent found, this node is a root
                self.roots.append(node)

        # Sort children of all nodes
        for code_key, node in self.nodes_dict.items():
            node._sort_children()

        # Set levels for all nodes
        for root in self.roots:
            root._set_levels()

    def find_parent_code(self, code_key):
        """Find the parent code by progressively stripping characters from the end."""
        for i in range(len(code_key) - 1, 0, -1):
            parent_code = code_key[:i]
            if parent_code in self.nodes_dict:
                return parent_code
        return None

    def print_tree(self):
        """Print all root nodes and their subtrees."""
        for root in self.roots:
            root.print_node()

    def get_node_by_code(self, code):
        """Retrieve a node by its code (with or without dots)."""
        code_key = code.replace('.', '')
        return self.nodes_dict.get(code_key)
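

# Illustrative usage sketch (not part of the library API). Assumes the default
# data file and embedding model, and that ICDNode exposes `code` and `long_desc`
# attributes, consistent with how they are used above.
if __name__ == "__main__":
    tree = get_tree()
    if tree is None:
        sys.exit(1)

    # Semantic search returns (node, L2 distance) pairs, closest first.
    for node, dist in tree.semantic_search("hurt tummy", limit=5):
        print(f"{dist:.3f}\t{node.code}\t{node.long_desc}")

    # Regex search over codes/descriptions, restricted to valid (billable) codes.
    for node in tree.regex_search(r"^E11", valid=1)[:5]:
        print(node.code, node.long_desc)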