# icdbot / icdtree.py
# oneilsh's picture
# trying huggingface spaces...
# c0aa211
from sentence_transformers import SentenceTransformer
from icdnode import ICDNode
import globals
import sys
import faiss
import numpy as np
import os
import zipfile
import requests
import shutil
def _ensure_icd_source_file(file_path, timeout=60):
    """Make sure the ICD-10-CM order file exists at ``file_path``.

    If the file is already present nothing is downloaded. Otherwise the CDC
    code-descriptions archive is downloaded (unless the zip is already on
    disk), extracted into ``data/code_descriptions``, and the extracted order
    file is copied to ``file_path`` when that is a different location.

    Failures that the original code treated as fatal (bad HTTP status,
    download error, invalid zip, copy error) still terminate the process via
    ``sys.exit(1)``; any other exception propagates to the caller.

    ``timeout`` is passed to ``requests.get`` so a stalled connection cannot
    hang the process indefinitely.
    """
    if os.path.exists(file_path):
        # The data file is already there; no network access is needed.
        return

    url = "https://ftp.cdc.gov/pub/Health_Statistics/NCHS/Publications/ICD10CM/2025/ICD10-CM%20Code%20Descriptions%202025.zip"
    extract_dir = "data/code_descriptions"
    zip_file_path = "data/code_descriptions/code_descriptions.zip"
    extracted_path = "data/code_descriptions/icd10cm-order-2025.txt"

    # Ensure the directory exists
    os.makedirs(extract_dir, exist_ok=True)

    # Download the archive if it doesn't exist
    if not os.path.exists(zip_file_path):
        try:
            sys.stderr.write("Downloading ICD data...\n")
            response = requests.get(url, timeout=timeout)
            if response.status_code != 200:
                sys.stderr.write(f"Failed to download file: {response.status_code}\n")
                sys.exit(1)
            # Write to a temporary name and rename, so an interrupted run
            # never leaves a truncated zip that later runs would trust.
            partial_path = zip_file_path + ".part"
            with open(partial_path, 'wb') as f:
                f.write(response.content)
            os.replace(partial_path, zip_file_path)
        except Exception as e:
            sys.stderr.write(f"Error downloading file: {e}\n")
            sys.exit(1)

    # Unzip the archive
    try:
        sys.stderr.write("Unzipping ICD data...\n")
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
    except zipfile.BadZipFile:
        sys.stderr.write("Error: The downloaded file is not a valid ZIP file.\n")
        sys.exit(1)

    # Copy the extracted file to the requested location (a no-op when the
    # caller asked for the default path, which is where extraction puts it).
    if not os.path.exists(file_path):
        try:
            sys.stderr.write("Copying ICD data...\n")
            target_dir = os.path.dirname(file_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            shutil.copy(extracted_path, file_path)
        except Exception as e:
            sys.stderr.write(f"Error copying file: {e}\n")
            sys.exit(1)


def get_tree(file_path="data/code_descriptions/icd10cm-order-2025.txt", st_embedding_model='intfloat/e5-small-v2'):
    """Return an ``ICDTree`` for ``file_path``, using an on-disk pickle cache.

    The cache file is ``data/<md5>.pkl`` where the digest is computed over
    ``file_path`` followed by ``st_embedding_model``. On a cache hit the tree
    is unpickled; on a miss (or an unreadable cache file) the source data is
    fetched if necessary, the tree is built, and the cache is written.
    In both cases ``build_faiss_index()`` is called on the tree before it is
    returned, because the index is created after pickling and is therefore
    not part of the cache.

    Returns ``None`` if fetching/unpacking the source data raises an
    ordinary exception. Some fetch failures exit the process instead (see
    ``_ensure_icd_source_file``).
    """
    import hashlib
    import pickle

    # md5 is used only to derive a stable cache file name, not for security.
    digest = hashlib.md5()
    digest.update(file_path.encode('utf-8'))
    digest.update(st_embedding_model.encode('utf-8'))
    pickle_file_name = "data/" + digest.hexdigest() + ".pkl"

    icd_tree = None
    try:
        with open(pickle_file_name, 'rb') as f:
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; this is only safe while data/*.pkl is written exclusively
            # by this function. Never point it at an untrusted file.
            icd_tree = pickle.load(f)
        sys.stderr.write("Loaded tree from pickle file\n")
    except FileNotFoundError:
        pass
    except (EOFError, pickle.UnpicklingError) as e:
        # A truncated/corrupt cache is treated as a miss and rebuilt.
        sys.stderr.write(f"Ignoring unreadable pickle file {pickle_file_name}: {e}\n")
        icd_tree = None

    if icd_tree is None:
        try:
            _ensure_icd_source_file(file_path)
        except Exception as e:
            sys.stderr.write("Error downloading or unzipping the file: " + str(e) + "\n")
            return None

        sys.stderr.write("Building tree\n")
        icd_tree = ICDTree(file_path, st_embedding_model)

        # Save atomically so an interrupted dump cannot leave a partial cache.
        os.makedirs("data", exist_ok=True)
        partial_pickle = pickle_file_name + ".tmp"
        with open(partial_pickle, 'wb') as f:
            pickle.dump(icd_tree, f)
        os.replace(partial_pickle, pickle_file_name)

    icd_tree.build_faiss_index()
    return icd_tree
class ICDTree:
    """Hierarchy of ICD nodes with embedding-based and regex search.

    Nodes are stored in ``nodes_dict`` keyed by their code with dots removed;
    ``roots`` holds the nodes for which no shorter code prefix exists in the
    dictionary. The FAISS index is created separately by
    ``build_faiss_index`` and its row order follows ``nodes_dict`` order.
    """

    def __init__(self, path, st_embedding_model):
        """Build the tree from the file at ``path`` and embed every node.

        ``st_embedding_model`` is the name passed to ``SentenceTransformer``.
        """
        self.nodes_dict = {}  # Dictionary to hold all nodes by code without dots
        self.roots = []  # List to hold root nodes
        self.build_tree(path)
        globals.total_nodes = len(self.nodes_dict)

        # e5-small-v2 seems to work well on synonym mapping (and not gated unlike the jina model): https://arxiv.org/html/2401.01943v2
        # intfloat's e5 models all require a "query: " prefix for semantic similarity tasks: see e.g. https://huggingface.co/intfloat/e5-small-v2
        self.embedding_model = SentenceTransformer(st_embedding_model)
        prefix = ""
        if "intfloat" in st_embedding_model and "e5" in st_embedding_model:
            prefix = "query: "
        # Remember the prefix so that search queries can be encoded the same
        # way as the node texts.
        self.embedding_prefix = prefix
        for root in self.roots:
            root._add_embeddings(self.embedding_model, prefix)

    def build_faiss_index(self):
        """Create ``self.index`` and add every node's embedding to it."""
        # we need to keep a faiss index for the embeddings
        # IndexFlatL2 seems to work slightly better than IndexFlatIP, at least for the one query I tried ('hurt tummy')
        dimension = self.embedding_model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatL2(dimension)
        #self.index = faiss.IndexFlatIP(self.embedding_model.get_sentence_embedding_dimension())
        embeddings = [node.embedding for node in self.nodes_dict.values()]
        if embeddings:
            # FAISS works on float32 matrices; an empty tree has nothing to add.
            self.index.add(np.asarray(embeddings, dtype=np.float32))

    def semantic_search(self, query, limit = 10, offset = 0):
        """Search for nodes similar to the query using the sentence transformer model.

        Returns a list of ``(node, distance)`` tuples sorted by ascending
        distance, with the first ``offset`` hits dropped. Fewer than
        ``limit`` tuples are returned when the index holds fewer entries.
        """
        wanted = limit + offset
        if wanted <= 0:
            return []

        # Nodes were embedded with this prefix (passed to _add_embeddings),
        # so the query gets the same one. Trees unpickled from before the
        # attribute existed fall back to the old, unprefixed behaviour.
        prefix = getattr(self, "embedding_prefix", "")
        query_embedding = np.asarray(
            self.embedding_model.encode([prefix + query])[0], dtype=np.float32
        )

        # search the faiss index
        D, I = self.index.search(query_embedding.reshape(1, -1), wanted)

        # Index rows follow nodes_dict order; materialise the list once
        # instead of once per hit.
        nodes = list(self.nodes_dict.values())
        results = []
        for distance, position in zip(D[0], I[0]):
            position = int(position)
            if position < 0:
                # FAISS pads missing neighbours with label -1; indexing the
                # list with it would wrongly return the last node.
                continue
            results.append((nodes[position], float(distance)))

        # sort by the node's level
        # results.sort(key=lambda x: x[0].level)
        # sort results by distance
        results.sort(key=lambda hit: hit[1])
        return results[offset:]

    # def get_descendants_yaml(self, code):
    #     """Returns a YAML string representing the node and its descendants, with indentation."""
    #     root = self.get_node_by_code(code)
    #     if root:
    #         return root.get_descendants_yaml()
    #     else:
    #         return None

    def regex_search(self, pattern, max_depth=None, valid = None):
        """Returns a list of node objects whose short_desc, long_desc, or
        code matches the regex pattern. max_depth is the maximum depth to
        search in the hierarchy; if None, search all levels.
        If valid is set to 1, only return valid nodes; if 0, only return invalid nodes. If None, return all matching nodes.
        """
        results = []
        for root in self.roots:
            results.extend(root.regex_search(pattern, max_depth = max_depth, valid = valid))
        return results

    def build_tree(self, file_path):
        """Read the fixed-width file at ``file_path`` and link nodes into a tree."""
        # First pass: Read the file and create all nodes
        with open(file_path, 'r') as f:
            for line in f:
                if not line.strip():
                    # A blank line would otherwise become a node with an
                    # empty code and show up as a bogus root.
                    continue
                order = line[0:5].strip()
                code = line[6:13].strip()
                valid = line[14:15].strip()
                short_desc = line[16:76].strip()
                long_desc = line[77:].strip()

                node = ICDNode(order, code, valid, short_desc, long_desc)
                code_key = node.code.replace('.', '')  # Key without dots for consistent lookup
                self.nodes_dict[code_key] = node

        # Second pass: Set parent-child relationships
        for code_key, node in self.nodes_dict.items():
            parent_code_key = self.find_parent_code(code_key)
            if parent_code_key is not None:
                self.nodes_dict[parent_code_key]._add_child(node)
            else:
                # No parent found, this node is a root
                self.roots.append(node)

        # Sort children of all nodes
        for node in self.nodes_dict.values():
            node._sort_children()

        # Set levels for all nodes
        for root in self.roots:
            root._set_levels()

    def find_parent_code(self, code_key):
        """Find the parent code by progressively stripping characters from the end.

        Returns the longest proper, non-empty prefix of ``code_key`` that is
        a key in ``nodes_dict``, or ``None`` if there is none.
        """
        for end in range(len(code_key) - 1, 0, -1):
            parent_code = code_key[:end]
            if parent_code in self.nodes_dict:
                return parent_code
        return None

    def print_tree(self):
        """Print all root nodes and their subtrees."""
        for root in self.roots:
            root.print_node()

    def get_node_by_code(self, code):
        """Retrieve a node by its code (with or without dots).

        The exact spelling is tried first; if that misses, surrounding
        whitespace is stripped and the code is upper-cased before a second
        lookup. Returns ``None`` when no node matches.
        """
        code_key = code.replace('.', '')
        node = self.nodes_dict.get(code_key)
        if node is None:
            node = self.nodes_dict.get(code_key.strip().upper())
        return node