"""Build, cache, and query a hierarchical ICD-10-CM code tree.

The tree is parsed from the fixed-width ICD-10-CM order file, each node is
embedded with a sentence-transformers model, and a FAISS index over those
embeddings backs semantic search. Built trees are cached as pickle files
under ``data/``.
"""

import contextlib
import hashlib
import os
import pickle
import shutil
import sys
import zipfile

import faiss
import numpy as np
import requests
from sentence_transformers import SentenceTransformer

import globals
from icdnode import ICDNode

# CDC/NCHS archive containing the 2025 ICD-10-CM code descriptions.
_ICD_ZIP_URL = "https://ftp.cdc.gov/pub/Health_Statistics/NCHS/Publications/ICD10CM/2025/ICD10-CM%20Code%20Descriptions%202025.zip"
_CODE_DESCRIPTIONS_DIR = "data/code_descriptions"
_ZIP_FILE_PATH = "data/code_descriptions/code_descriptions.zip"
# Where the order file is expected to land after extracting the archive
# (also get_tree's default input path).
_DEFAULT_CODE_FILE = "data/code_descriptions/icd10cm-order-2025.txt"
# Seconds to wait on the CDC server; without a timeout a stalled connection
# would hang tree construction indefinitely.
_DOWNLOAD_TIMEOUT_SECONDS = 120


def _query_prefix(st_embedding_model):
    """Return the text prefix the embedding model expects on its inputs.

    intfloat's e5 models expect a "query: " prefix for symmetric
    semantic-similarity tasks (see https://huggingface.co/intfloat/e5-small-v2);
    every other model gets no prefix.
    """
    if "intfloat" in st_embedding_model and "e5" in st_embedding_model:
        return "query: "
    return ""


def _download_code_descriptions(file_path):
    """Download and extract the CDC archive so that *file_path* exists.

    Exits the process with status 1 if the download, extraction, or the final
    copy to *file_path* fails.
    """
    os.makedirs(_CODE_DESCRIPTIONS_DIR, exist_ok=True)

    # Download the archive only if an earlier run has not already fetched it.
    if not os.path.exists(_ZIP_FILE_PATH):
        try:
            sys.stderr.write("Downloading ICD data...\n")
            response = requests.get(_ICD_ZIP_URL, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
            if response.status_code != 200:
                sys.stderr.write(f"Failed to download file: {response.status_code}\n")
                sys.exit(1)
            # Write under a temporary name and rename afterwards. An interrupted
            # write then never leaves a truncated zip that later runs would
            # mistake for a complete download.
            partial_path = _ZIP_FILE_PATH + ".part"
            with open(partial_path, 'wb') as f:
                f.write(response.content)
            os.replace(partial_path, _ZIP_FILE_PATH)
        except Exception as e:  # SystemExit from sys.exit above is not caught here
            sys.stderr.write(f"Error downloading file: {e}\n")
            sys.exit(1)

    try:
        sys.stderr.write("Unzipping ICD data...\n")
        with zipfile.ZipFile(_ZIP_FILE_PATH, 'r') as zip_ref:
            zip_ref.extractall(_CODE_DESCRIPTIONS_DIR)
    except zipfile.BadZipFile:
        sys.stderr.write("Error: The downloaded file is not a valid ZIP file.\n")
        # Remove the bad archive so the next run downloads it again instead of
        # failing on the same file forever.
        with contextlib.suppress(OSError):
            os.remove(_ZIP_FILE_PATH)
        sys.exit(1)

    # The caller may want the order file somewhere other than where the archive
    # extracts it; copy it there.
    if not os.path.exists(file_path):
        try:
            sys.stderr.write("Copying ICD data...\n")
            os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
            shutil.copy(_DEFAULT_CODE_FILE, file_path)
        except Exception as e:
            sys.stderr.write(f"Error copying file: {e}\n")
            sys.exit(1)


def get_tree(file_path=_DEFAULT_CODE_FILE, st_embedding_model='intfloat/e5-small-v2'):
    """Return an ICDTree for *file_path*, using a pickle cache when possible.

    The cache file is ``data/<md5(file_path + st_embedding_model)>.pkl``. If it
    is missing or unreadable, the tree is built from *file_path*, then pickled.
    If *file_path* is missing, the CDC code descriptions are downloaded first.
    The FAISS index is (re)built on every call, because it is not part of the
    pickle.

    Args:
        file_path: Path to the fixed-width ICD-10-CM order file.
        st_embedding_model: sentence-transformers model name used for node
            embeddings.

    Returns:
        The ICDTree with its FAISS index built, or None if preparing the code
        file failed with an unexpected error. Download, unzip, and copy
        failures exit the process instead.
    """
    # Cache key: md5 of the file path and model name. It is left in this exact
    # form so pickles written by earlier versions are still found.
    m = hashlib.md5()
    m.update(file_path.encode('utf-8'))
    m.update(st_embedding_model.encode('utf-8'))
    pickle_file_name = "data/" + m.hexdigest() + ".pkl"

    icd_tree = None
    try:
        # NOTE: pickle.load can execute arbitrary code. Only load cache files
        # this program wrote itself; never point this at untrusted data.
        with open(pickle_file_name, 'rb') as f:
            icd_tree = pickle.load(f)
        sys.stderr.write("Loaded tree from pickle file\n")
    except FileNotFoundError:
        pass
    except (EOFError, pickle.UnpicklingError) as e:
        # A truncated or corrupt cache should trigger a rebuild, not a crash.
        sys.stderr.write(f"Ignoring unreadable pickle file {pickle_file_name}: {e}\n")

    if icd_tree is None:
        if not os.path.exists(file_path):
            try:
                _download_code_descriptions(file_path)
            except Exception as e:
                sys.stderr.write("Error downloading or unzipping the file: " + str(e) + "\n")
                return None

        sys.stderr.write("Building tree\n")
        icd_tree = ICDTree(file_path, st_embedding_model)

        # Write atomically so an interrupted dump never leaves a corrupt cache.
        os.makedirs(os.path.dirname(pickle_file_name), exist_ok=True)
        tmp_name = pickle_file_name + ".tmp"
        try:
            with open(tmp_name, 'wb') as f:
                pickle.dump(icd_tree, f)
            os.replace(tmp_name, pickle_file_name)
        finally:
            with contextlib.suppress(FileNotFoundError):
                os.remove(tmp_name)
    elif not hasattr(icd_tree, "query_prefix"):
        # Trees pickled before query_prefix was stored: derive it from the model
        # name so queries are prefixed the same way as the node embeddings.
        icd_tree.query_prefix = _query_prefix(st_embedding_model)

    icd_tree.build_faiss_index()
    return icd_tree


class ICDTree:
    """Hierarchy of ICD-10-CM codes with regex and embedding-based search."""

    def __init__(self, path, st_embedding_model):
        """Parse *path* into a tree and embed every node with *st_embedding_model*."""
        self.nodes_dict = {}  # All nodes, keyed by code with dots removed
        self.roots = []  # Nodes whose code has no ancestor prefix in the file
        self.build_tree(path)
        globals.total_nodes = len(self.nodes_dict)

        # e5-small-v2 seems to work well on synonym mapping (and, unlike the
        # jina model, is not gated): https://arxiv.org/html/2401.01943v2
        self.embedding_model = SentenceTransformer(st_embedding_model)
        # Stored so semantic_search can give queries the same prefix as the
        # node text, keeping both sides in the same embedding "mode".
        self.query_prefix = _query_prefix(st_embedding_model)
        for root in self.roots:
            # Presumably embeds every node in the subtree and sets node.embedding
            # (read by build_faiss_index) -- see ICDNode._add_embeddings.
            root._add_embeddings(self.embedding_model, self.query_prefix)

    def build_faiss_index(self):
        """Build an exact L2 FAISS index over all node embeddings.

        Row i of the index corresponds to ``self._index_nodes[i]``. The list is
        captured here so search results map back to exactly the nodes the
        index was built from.
        """
        # IndexFlatL2 seemed to work slightly better than IndexFlatIP in an
        # informal test (the query 'hurt tummy').
        self.index = faiss.IndexFlatL2(self.embedding_model.get_sentence_embedding_dimension())
        self._index_nodes = list(self.nodes_dict.values())
        if self._index_nodes:
            # FAISS requires a contiguous float32 matrix.
            embeddings = np.asarray(
                [node.embedding for node in self._index_nodes], dtype=np.float32
            )
            self.index.add(embeddings)

    def semantic_search(self, query, limit=10, offset=0):
        """Return the nodes whose embeddings are nearest to *query*.

        Args:
            query: Free-text search string.
            limit: Maximum number of results to return.
            offset: Number of best-ranked results to skip (for paging).

        Returns:
            A list of ``(node, distance)`` tuples sorted by ascending distance,
            as reported by FAISS (IndexFlatL2 reports squared L2 distances).
            The list may be shorter than *limit* when the index holds fewer
            nodes.
        """
        k = limit + offset
        if k <= 0:
            return []

        prefix = getattr(self, "query_prefix", "")
        query_embedding = self.embedding_model.encode([prefix + query])
        query_matrix = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        distances, indices = self.index.search(query_matrix, k)

        nodes = getattr(self, "_index_nodes", None)
        if nodes is None:
            nodes = list(self.nodes_dict.values())

        results = [
            (nodes[idx], float(dist))
            for idx, dist in zip(indices[0], distances[0])
            # FAISS pads with label -1 when fewer than k vectors exist; skip those
            # instead of letting nodes[-1] silently return the last node.
            if idx >= 0
        ]
        results.sort(key=lambda pair: pair[1])
        return results[offset:]

    def regex_search(self, pattern, max_depth=None, valid=None):
        """Returns a list of node objects whose short_desc, long_desc, or code matches the regex pattern.

        max_depth is the maximum depth to search in the hierarchy; if None, search all levels.
        If valid is set to 1, only return valid nodes; if 0, only return invalid nodes. If None, return all matching nodes.
        """
        results = []
        for root in self.roots:
            results.extend(root.regex_search(pattern, max_depth=max_depth, valid=valid))
        return results

    def build_tree(self, file_path):
        """Parse the fixed-width order file at *file_path* into linked nodes."""
        # First pass: create one node per non-blank line. Fields are fixed-width
        # columns, sliced below.
        with open(file_path, 'r') as f:
            for line in f:
                if not line.strip():
                    continue  # a blank line would otherwise become an empty-code root
                order = line[0:5].strip()
                code = line[6:13].strip()
                valid = line[14:15].strip()
                short_desc = line[16:76].strip()
                long_desc = line[77:].strip()
                node = ICDNode(order, code, valid, short_desc, long_desc)
                code_key = node.code.replace('.', '')  # key without dots for consistent lookup
                self.nodes_dict[code_key] = node

        # Second pass: link each node to its nearest existing ancestor code.
        for code_key, node in self.nodes_dict.items():
            parent_code_key = self.find_parent_code(code_key)
            if parent_code_key is not None:
                self.nodes_dict[parent_code_key]._add_child(node)
            else:
                self.roots.append(node)

        for node in self.nodes_dict.values():
            node._sort_children()

        for root in self.roots:
            root._set_levels()

    def find_parent_code(self, code_key):
        """Return the longest proper prefix of *code_key* that is a known code, or None."""
        for i in range(len(code_key) - 1, 0, -1):
            parent_code = code_key[:i]
            if parent_code in self.nodes_dict:
                return parent_code
        return None

    def print_tree(self):
        """Print all root nodes and their subtrees."""
        for root in self.roots:
            root.print_node()

    def get_node_by_code(self, code):
        """Retrieve a node by its code (with or without dots); None if absent."""
        code_key = code.replace('.', '')
        return self.nodes_dict.get(code_key)