File size: 8,462 Bytes
c0aa211 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
from sentence_transformers import SentenceTransformer
from icdnode import ICDNode
import globals
import sys
import faiss
import numpy as np
import os
import zipfile
import requests
import shutil
def _ensure_icd_data_file(file_path):
    """Return a path to an ICD-10-CM order file, downloading the CDC release if needed.

    If *file_path* already exists it is returned unchanged and nothing is
    downloaded. Otherwise the 2025 ICD-10-CM code-description archive is
    downloaded (unless a cached zip is present) and extracted into
    ``data/code_descriptions``. Then *file_path* is returned if it now exists,
    else the default 2025 order file (the historical fallback).

    Returns:
        A path to an existing order file, or None if an unexpected error
        occurred while downloading/unzipping. Download failures, an invalid
        archive, and a missing order file after extraction terminate the
        process with ``sys.exit(1)``.
    """
    if os.path.exists(file_path):
        return file_path

    data_dir = "data/code_descriptions"
    default_path = "data/code_descriptions/icd10cm-order-2025.txt"
    url = "https://ftp.cdc.gov/pub/Health_Statistics/NCHS/Publications/ICD10CM/2025/ICD10-CM%20Code%20Descriptions%202025.zip"
    zip_file_path = "data/code_descriptions/code_descriptions.zip"
    try:
        os.makedirs(data_dir, exist_ok=True)
        # Reuse a previously downloaded archive if one is present.
        if not os.path.exists(zip_file_path):
            try:
                sys.stderr.write("Downloading ICD data...\n")
                # Without a timeout a stalled server would hang the process forever.
                response = requests.get(url, timeout=300)
                if response.status_code != 200:
                    sys.stderr.write(f"Failed to download file: {response.status_code}\n")
                    sys.exit(1)
                with open(zip_file_path, 'wb') as f:
                    f.write(response.content)
            except Exception as e:
                sys.stderr.write(f"Error downloading file: {e}\n")
                sys.exit(1)
        try:
            sys.stderr.write("Unzipping ICD data...\n")
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(data_dir)
        except zipfile.BadZipFile:
            sys.stderr.write("Error: The downloaded file is not a valid ZIP file.\n")
            # Drop the bad archive so the next run downloads it again instead of
            # failing on the same cached file forever.
            os.remove(zip_file_path)
            sys.exit(1)
        if not os.path.exists(default_path):
            # The archive layout may nest the order file in a subdirectory;
            # locate it and copy it to where the rest of the code expects it.
            wanted = os.path.basename(default_path)
            found = None
            for dirpath, _dirnames, filenames in os.walk(data_dir):
                if wanted in filenames:
                    found = os.path.join(dirpath, wanted)
                    break
            if found is None:
                sys.stderr.write(f"Error copying file: {wanted} not found in the extracted archive\n")
                sys.exit(1)
            try:
                sys.stderr.write("Copying ICD data...\n")
                shutil.copy(found, default_path)
            except Exception as e:
                sys.stderr.write(f"Error copying file: {e}\n")
                sys.exit(1)
    except Exception as e:
        sys.stderr.write("Error downloading or unzipping the file: " + str(e) + "\n")
        return None

    if os.path.exists(file_path):
        return file_path
    # Historical behaviour: fall back to the default 2025 order file when the
    # requested path is still missing (default_path exists here, or we exited above).
    sys.stderr.write(f"Warning: {file_path} not found; using {default_path}\n")
    return default_path


def get_tree(file_path="data/code_descriptions/icd10cm-order-2025.txt", st_embedding_model='intfloat/e5-small-v2'):
    """Return an ICDTree for *file_path*, using an on-disk pickle cache.

    The cache file is ``data/<md5 of file_path and st_embedding_model>.pkl``.
    On a cache miss, or when the cache file cannot be unpickled, the ICD data
    file is located (downloaded from the CDC if it is missing), parsed and
    embedded, and the tree is pickled. The faiss index is built only after
    pickling, so it is never stored in the cache; it is rebuilt on every call.

    Args:
        file_path: fixed-width ICD-10-CM "order" file to parse.
        st_embedding_model: sentence-transformers model used to embed nodes.

    Returns:
        The ICDTree, or None if the data could not be downloaded/unzipped.
        Some download failures terminate the process via ``sys.exit(1)``.
    """
    import hashlib
    import pickle

    # md5 is only a cache key here, not a security measure. The key format
    # (file_path bytes immediately followed by model-name bytes) is unchanged
    # so existing cache files stay valid.
    m = hashlib.md5()
    m.update(file_path.encode('utf-8'))
    m.update(st_embedding_model.encode('utf-8'))
    pickle_file_name = "data/" + m.hexdigest() + ".pkl"

    icd_tree = None
    try:
        # NOTE: pickle can execute arbitrary code on load; this cache file must
        # only ever be written by this function.
        with open(pickle_file_name, 'rb') as f:
            icd_tree = pickle.load(f)
        sys.stderr.write("Loaded tree from pickle file\n")
    except FileNotFoundError:
        pass
    except (pickle.UnpicklingError, EOFError, AttributeError, ImportError) as e:
        # Truncated write or a class layout that no longer matches: rebuild.
        sys.stderr.write(f"Ignoring unreadable cache file {pickle_file_name}: {e}\n")
        icd_tree = None

    if icd_tree is None:
        source_path = _ensure_icd_data_file(file_path)
        if source_path is None:
            return None
        sys.stderr.write("Building tree\n")
        icd_tree = ICDTree(source_path, st_embedding_model)
        os.makedirs(os.path.dirname(pickle_file_name), exist_ok=True)
        # Dump to a temporary file and rename it into place so an interrupted
        # write never leaves a half-written cache behind.
        tmp_file_name = pickle_file_name + ".tmp"
        with open(tmp_file_name, 'wb') as f:
            pickle.dump(icd_tree, f)
        os.replace(tmp_file_name, pickle_file_name)

    icd_tree.build_faiss_index()
    return icd_tree
class ICDTree:
    """ICD-10-CM code hierarchy with regex and embedding-based search.

    Nodes are parsed from a fixed-width ICD-10-CM "order" file and linked into
    a parent/child hierarchy by code prefix. Each node receives a sentence
    embedding so the tree can be searched semantically through a faiss index
    (built separately by :meth:`build_faiss_index`).
    """

    def __init__(self, path, st_embedding_model):
        """Parse *path*, build the hierarchy, and embed every node.

        Args:
            path: fixed-width ICD-10-CM order file.
            st_embedding_model: sentence-transformers model name or path.
        """
        self.nodes_dict = {}  # all nodes, keyed by code without dots
        self.roots = []  # nodes whose code has no prefix present in the file
        self.build_tree(path)
        globals.total_nodes = len(self.nodes_dict)
        # e5-small-v2 seems to work well on synonym mapping (and is not gated,
        # unlike the jina model): https://arxiv.org/html/2401.01943v2
        self.st_embedding_model = st_embedding_model
        self.embedding_model = SentenceTransformer(st_embedding_model)
        # Stored so semantic_search can prefix queries the same way as nodes.
        self.embedding_prefix = self._embedding_prefix_for(st_embedding_model)
        for root in self.roots:
            # Presumably ICDNode prepends the prefix to each description
            # before encoding -- TODO confirm in icdnode.
            root._add_embeddings(self.embedding_model, self.embedding_prefix)

    @staticmethod
    def _embedding_prefix_for(st_embedding_model):
        """Return the text prefix the model expects for similarity tasks.

        intfloat's e5 models require a "query: " prefix for semantic similarity
        (see https://huggingface.co/intfloat/e5-small-v2); other models get "".
        """
        if "intfloat" in st_embedding_model and "e5" in st_embedding_model:
            return "query: "
        return ""

    def build_faiss_index(self):
        """(Re)build an exact L2 faiss index over all node embeddings.

        Rows are added in ``self.nodes_dict`` iteration order. That node list
        is remembered so faiss ids can be mapped back to nodes in
        :meth:`semantic_search`.
        """
        # IndexFlatL2 seemed slightly better than IndexFlatIP in an informal
        # test (the query 'hurt tummy').
        self.index = faiss.IndexFlatL2(self.embedding_model.get_sentence_embedding_dimension())
        self._index_nodes = list(self.nodes_dict.values())
        if not self._index_nodes:
            return  # nothing to add; searches will return no results
        # faiss only accepts float32 input.
        embeddings = np.asarray([node.embedding for node in self._index_nodes], dtype='float32')
        self.index.add(embeddings)

    def semantic_search(self, query, limit=10, offset=0):
        """Return nodes whose embeddings are closest to *query*.

        Args:
            query: free-text query.
            limit: maximum number of results to return.
            offset: number of best results to skip (for paging).

        Returns:
            A list of ``(node, distance)`` tuples sorted by ascending faiss
            distance (IndexFlatL2 reports squared Euclidean distance). The list
            may be shorter than *limit* when the index has too few entries.
        """
        k = limit + offset
        if limit <= 0 or k <= 0:
            return []
        if getattr(self, 'index', None) is None:
            self.build_faiss_index()
        # Use the same prefix the node texts were embedded with. Trees
        # unpickled from older caches lack the attribute and use no prefix.
        prefix = getattr(self, 'embedding_prefix', '')
        query_embedding = self.embedding_model.encode([prefix + query])[0]
        distances, ids = self.index.search(np.array([query_embedding], dtype='float32'), k)
        nodes = getattr(self, '_index_nodes', None)
        if nodes is None:
            nodes = list(self.nodes_dict.values())
        # faiss pads with id -1 when fewer than k vectors are available.
        results = [
            (nodes[idx], float(dist))
            for dist, idx in zip(distances[0], ids[0])
            if idx >= 0
        ]
        results.sort(key=lambda x: x[1])
        return results[offset:]

    def regex_search(self, pattern, max_depth=None, valid=None):
        """Returns a list of node objects whose short_desc, long_desc, or
        code matches the regex pattern. max_depth is the maximum depth to
        search in the hierarchy; if None, search all levels.
        If valid is set to 1, only return valid nodes; if 0, only return invalid nodes. If None, return all matching nodes.
        """
        results = []
        for root in self.roots:
            results.extend(root.regex_search(pattern, max_depth=max_depth, valid=valid))
        return results

    def build_tree(self, file_path):
        """Parse the fixed-width order file and link nodes into a hierarchy.

        Column layout (0-based slices): order [0:5], code [6:13],
        valid flag [14:15], short description [16:76], long description [77:].
        Blank lines are skipped.
        """
        # First pass: read the file and create all nodes.
        with open(file_path, 'r') as f:
            for line in f:
                if not line.strip():
                    continue  # a blank line would otherwise become a root with code ""
                order = line[0:5].strip()
                code = line[6:13].strip()
                valid = line[14:15].strip()
                short_desc = line[16:76].strip()
                long_desc = line[77:].strip()
                node = ICDNode(order, code, valid, short_desc, long_desc)
                code_key = node.code.replace('.', '')  # key without dots for consistent lookup
                self.nodes_dict[code_key] = node
        # Second pass: set parent-child relationships.
        for code_key, node in self.nodes_dict.items():
            parent_code_key = self.find_parent_code(code_key)
            if parent_code_key:
                self.nodes_dict[parent_code_key]._add_child(node)
            else:
                # No parent found, so this node is a root.
                self.roots.append(node)
        for node in self.nodes_dict.values():
            node._sort_children()
        for root in self.roots:
            root._set_levels()

    def find_parent_code(self, code_key):
        """Find the parent code by progressively stripping characters from the end.

        Returns the longest proper prefix of *code_key* present in
        ``nodes_dict``, or None if there is none.
        """
        for i in range(len(code_key) - 1, 0, -1):
            parent_code = code_key[:i]
            if parent_code in self.nodes_dict:
                return parent_code
        return None

    def print_tree(self):
        """Print all root nodes and their subtrees."""
        for root in self.roots:
            root.print_node()

    def get_node_by_code(self, code):
        """Retrieve a node by its code (with or without dots); None if absent.

        Surrounding whitespace is ignored, and lowercase input is accepted
        because ICD-10-CM codes are uppercase.
        """
        code_key = code.strip().replace('.', '').upper()
        return self.nodes_dict.get(code_key)