|
|
from sentence_transformers import SentenceTransformer |
|
|
from icdnode import ICDNode |
|
|
import globals |
|
|
import sys |
|
|
import faiss |
|
|
import numpy as np |
|
|
import os |
|
|
import zipfile |
|
|
import requests |
|
|
import shutil |
|
|
|
|
|
def _ensure_icd_source_file(file_path):
    """Make sure the ICD-10-CM order file exists at ``file_path``.

    If the file is already present nothing happens (no network access).
    Otherwise the CDC code-description archive is downloaded (unless the zip
    is already cached), extracted into ``data/code_descriptions`` and, when
    ``file_path`` differs from the extracted file, copied to ``file_path``.

    Download, unzip and copy failures are reported on stderr and terminate
    the process with exit status 1.
    """
    if os.path.exists(file_path):
        return

    url = "https://ftp.cdc.gov/pub/Health_Statistics/NCHS/Publications/ICD10CM/2025/ICD10-CM%20Code%20Descriptions%202025.zip"
    extract_dir = "data/code_descriptions"
    zip_file_path = "data/code_descriptions/code_descriptions.zip"
    extracted_file = "data/code_descriptions/icd10cm-order-2025.txt"

    os.makedirs(extract_dir, exist_ok=True)

    # Download the archive only when no cached copy is available.
    if not os.path.exists(zip_file_path):
        try:
            sys.stderr.write("Downloading ICD data...\n")
            # A timeout keeps a stalled connection from hanging forever.
            response = requests.get(url, timeout=60)
            if response.status_code != 200:
                sys.stderr.write(f"Failed to download file: {response.status_code}\n")
                sys.exit(1)
            # Write to a temporary name first so an interrupted download
            # never leaves a truncated zip that poisons later runs.
            partial_path = zip_file_path + ".part"
            with open(partial_path, 'wb') as f:
                f.write(response.content)
            os.replace(partial_path, zip_file_path)
        except Exception as e:
            sys.stderr.write(f"Error downloading file: {e}\n")
            sys.exit(1)

    try:
        sys.stderr.write("Unzipping ICD data...\n")
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
    except zipfile.BadZipFile:
        sys.stderr.write("Error: The downloaded file is not a valid ZIP file.\n")
        sys.exit(1)

    # Place the extracted order file where the caller asked for it.
    if not os.path.exists(file_path):
        try:
            sys.stderr.write("Copying ICD data...\n")
            target_dir = os.path.dirname(file_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            shutil.copy(extracted_file, file_path)
        except Exception as e:
            sys.stderr.write(f"Error copying file: {e}\n")
            sys.exit(1)


def get_tree(file_path = "data/code_descriptions/icd10cm-order-2025.txt", st_embedding_model = 'intfloat/e5-small-v2'):
    """Return an ``ICDTree`` for ``file_path``, using a pickle cache when possible.

    The cache file name is derived from an MD5 digest of ``file_path`` and
    ``st_embedding_model``, so each (file, model) pair gets its own cache
    under ``data/``. On a cache miss the source file is fetched if necessary,
    the tree is built and the result is pickled. The FAISS index is rebuilt
    on every call, after the tree was loaded or built and pickled.

    Args:
        file_path: Path of the ICD-10-CM order file to build the tree from.
        st_embedding_model: Sentence-transformers model name used for the
            node embeddings.

    Returns:
        The ready-to-search ``ICDTree``, or ``None`` when preparing the source
        file raised an unexpected exception. Download/unzip/copy failures
        exit the process with status 1 instead.
    """
    import hashlib
    import pickle

    # MD5 is only used as a stable cache key here, not for security; keeping
    # it means existing cache files stay valid.
    m = hashlib.md5()
    m.update(file_path.encode('utf-8'))
    m.update(st_embedding_model.encode('utf-8'))
    pickle_file_name = "data/" + m.hexdigest() + ".pkl"

    icd_tree = None
    try:
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # the cache directory must only ever contain files written by this
        # function.
        with open(pickle_file_name, 'rb') as f:
            icd_tree = pickle.load(f)
        sys.stderr.write("Loaded tree from pickle file\n")
    except FileNotFoundError:
        try:
            _ensure_icd_source_file(file_path)
        except Exception as e:
            sys.stderr.write("Error downloading or unzipping the file: " + str(e) + "\n")
            return None

        sys.stderr.write("Building tree\n")
        icd_tree = ICDTree(file_path, st_embedding_model)

        # The source file may live outside data/, so make sure the cache
        # directory exists, and write atomically so a crash mid-dump cannot
        # leave a truncated pickle behind.
        os.makedirs("data", exist_ok=True)
        partial_pickle = pickle_file_name + ".part"
        with open(partial_pickle, 'wb') as f:
            pickle.dump(icd_tree, f)
        os.replace(partial_pickle, pickle_file_name)

    # Built after pickling: the index is never part of the cached object.
    icd_tree.build_faiss_index()
    return icd_tree
|
|
|
|
|
class ICDTree:
    """Hierarchy of ICD nodes parsed from a fixed-width order file.

    Attributes:
        nodes_dict: Maps a code with dots removed to its ``ICDNode``, in file
            order.
        roots: Nodes for which no shorter code prefix exists in ``nodes_dict``.
        embedding_model: The ``SentenceTransformer`` used for embeddings.
        index: FAISS L2 index over the node embeddings; only present after
            ``build_faiss_index`` has been called.
    """

    def __init__(self, path, st_embedding_model):
        """Parse ``path`` into a tree and compute an embedding for every node.

        Args:
            path: Path of the fixed-width ICD order file.
            st_embedding_model: Sentence-transformers model name.
        """
        self.nodes_dict = {}
        self.roots = []
        self.build_tree(path)
        globals.total_nodes = len(self.nodes_dict)

        self.embedding_model = SentenceTransformer(st_embedding_model)

        # intfloat e5 models get the "query: " prefix passed to the nodes.
        prefix = ""
        if "intfloat" in st_embedding_model and "e5" in st_embedding_model:
            prefix = "query: "

        for root in self.roots:
            root._add_embeddings(self.embedding_model, prefix)

    def build_faiss_index(self):
        """Create ``self.index`` (flat L2) from the embeddings of all nodes.

        Vectors are added in ``nodes_dict`` order, so a FAISS result label is
        a position in ``list(self.nodes_dict.values())``.
        """
        dimension = self.embedding_model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatL2(dimension)

        embeddings = [node.embedding for node in self.nodes_dict.values()]
        if embeddings:
            # FAISS expects a 2-D float32 matrix.
            self.index.add(np.asarray(embeddings, dtype=np.float32))

    def semantic_search(self, query, limit = 10, offset = 0):
        """Search for nodes similar to the query using the sentence transformer model.

        Args:
            query: Free-text query.
            limit: Maximum number of results to return.
            offset: Number of best matches to skip (for paging).

        Returns:
            A list of ``(node, distance)`` tuples sorted by ascending L2
            distance. It may hold fewer than ``limit`` entries when the index
            contains fewer vectors than ``limit + offset``.
        """
        k = limit + offset
        if k <= 0:
            return []

        # NOTE(review): the prefix handed to _add_embeddings in __init__ is
        # not applied to the query here — confirm whether that is intended.
        query_embedding = self.embedding_model.encode([query])[0]
        D, I = self.index.search(np.asarray([query_embedding], dtype=np.float32), k)

        # Materialise the positional lookup once instead of once per hit.
        nodes = list(self.nodes_dict.values())

        results = []
        for distance, label in zip(D[0], I[0]):
            label = int(label)
            if label < 0:
                # FAISS pads missing neighbours with -1; indexing with it
                # would silently return the last node.
                continue
            results.append((nodes[label], float(distance)))

        results.sort(key=lambda x: x[1])
        return results[offset:]

    def regex_search(self, pattern, max_depth=None, valid = None):
        """Return the nodes found by each root's ``regex_search``.

        The arguments are forwarded unchanged to every root node and the
        per-root result lists are concatenated in root order.

        Args:
            pattern: Regex pattern handed to the nodes.
            max_depth: Maximum depth to search in the hierarchy; ``None``
                searches all levels.
            valid: ``1`` to return only valid nodes, ``0`` to return only
                invalid nodes, ``None`` to return all matching nodes.
        """
        results = []
        for root in self.roots:
            results.extend(root.regex_search(pattern, max_depth = max_depth, valid = valid))
        return results

    def build_tree(self, file_path):
        """Read the order file and link the nodes into parent/child trees.

        Each line is sliced by fixed columns: order ``[0:5]``, code
        ``[6:13]``, valid flag ``[14:15]``, short description ``[16:76]`` and
        long description ``[77:]``. Lines with an empty code field (e.g. a
        trailing blank line) are skipped.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                code = line[6:13].strip()
                if not code:
                    # Would otherwise create a phantom root with an empty key.
                    continue

                order = line[0:5].strip()
                valid = line[14:15].strip()
                short_desc = line[16:76].strip()
                long_desc = line[77:].strip()

                node = ICDNode(order, code, valid, short_desc, long_desc)
                code_key = node.code.replace('.', '')
                self.nodes_dict[code_key] = node

        # Attach every node to its closest existing prefix, or make it a root.
        for code_key, node in self.nodes_dict.items():
            parent_code_key = self.find_parent_code(code_key)
            if parent_code_key is None:
                self.roots.append(node)
            else:
                self.nodes_dict[parent_code_key]._add_child(node)

        for node in self.nodes_dict.values():
            node._sort_children()

        for root in self.roots:
            root._set_levels()

    def find_parent_code(self, code_key):
        """Find the parent code by progressively stripping characters from the end.

        Returns the longest proper, non-empty prefix of ``code_key`` that is
        a key of ``nodes_dict``, or ``None`` when there is none.
        """
        for end in range(len(code_key) - 1, 0, -1):
            candidate = code_key[:end]
            if candidate in self.nodes_dict:
                return candidate
        return None

    def print_tree(self):
        """Print all root nodes and their subtrees."""
        for root in self.roots:
            root.print_node()

    def get_node_by_code(self, code):
        """Retrieve a node by its code (with or without dots).

        Returns ``None`` when the code is unknown.
        """
        return self.nodes_dict.get(code.replace('.', ''))