|
|
import re |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import globals |
|
|
import sys |
|
|
|
|
|
class ICDNode:
    """A node in a tree of ICD codes.

    Each node stores its code and descriptions, links to its parent and
    children, an optional embedding vector, and its depth (``level``) in
    the tree.
    """

    def __init__(self, order, code, valid, short_desc, long_desc):
        """Create an unattached node.

        Args:
            order: Stored as-is on ``self.order``; presumably the record's
                position in the source listing (TODO confirm with caller).
            code: Code string. Any dots are removed, and for codes longer
                than three characters a single dot is re-inserted after the
                third character (e.g. ``"A001"`` -> ``"A00.1"``).
            valid: Validity flag, stored as-is. ``regex_search`` compares it
                with ``==`` against its ``valid`` filter.
            short_desc: Short description. This is the text that
                ``_add_embeddings`` encodes.
            long_desc: Long description.
        """
        self.order = order

        # Normalize the code so it carries exactly one dot, after the
        # 3-character prefix (only when the code is longer than 3 chars).
        code = code.replace('.', '')
        if len(code) > 3:
            self.code = code[:3] + '.' + code[3:]
        else:
            self.code = code

        self.valid = valid
        self.short_desc = short_desc
        self.long_desc = long_desc
        self.children = []    # immediate child nodes, in insertion order
        self.parent = None    # set by the parent's _add_child()
        self.embedding = None # filled in by _add_embeddings()
        self.level = 0        # depth in the tree; recomputed by _set_levels()

    def _set_levels(self):
        """Recompute ``level`` for this node and all of its descendants.

        A node with a parent gets ``parent.level + 1``. A node without a
        parent keeps its current level, which is 0 unless changed. The
        update then recurses into every child.
        """
        if self.parent:
            self.level = self.parent.level + 1
        for child in self.children:
            child._set_levels()

    def get_parent(self):
        """Return the parent node of this node (``None`` for a root)."""
        return self.parent

    def get_children(self):
        """Return the list of immediate child nodes of this node."""
        return self.children

    def get_descendants(self):
        """Return all nodes below this one as a list, in depth-first pre-order."""
        descendants = []
        for child in self.children:
            descendants.append(child)
            descendants.extend(child.get_descendants())
        return descendants

    def get_ancestors(self):
        """Return this node's ancestors, nearest first (parent, grandparent, ...)."""
        ancestors = []
        parent = self.parent
        while parent:
            ancestors.append(parent)
            parent = parent.parent
        return ancestors

    def _add_embeddings(self, embedding_model, prefix=""):
        """Embed this node and all its descendants, depth-first.

        Each node's ``embedding`` is set to
        ``embedding_model.encode([prefix + short_desc])[0]``.

        For every node embedded, the counter ``globals.num_embedded`` is
        incremented and a progress line (percentage plus the node's repr) is
        written to stderr. ``globals`` is the project module imported at file
        level, not the builtin. ``globals.total_nodes`` must be set to a
        nonzero total beforehand, because it is the progress denominator.
        """
        self.embedding = embedding_model.encode([prefix + self.short_desc])[0]
        globals.num_embedded += 1
        percent_complete = round(100 * (globals.num_embedded / globals.total_nodes), 2)
        sys.stderr.write(f"({percent_complete}%) " + str(self) + "\n")
        for child in self.children:
            child._add_embeddings(embedding_model, prefix)

    def _sort_children(self):
        """Sort this node's immediate children in place by their code string."""
        self.children.sort(key=lambda x: x.code)

    def _add_child(self, child_node):
        """Append ``child_node`` to this node's children and set its parent link.

        Levels are not updated here; call ``_set_levels()`` afterwards.
        """
        self.children.append(child_node)
        child_node.parent = self

    def regex_search(self, pattern, max_depth=None, current_depth=0, valid=None):
        """Return the nodes in this subtree that match ``pattern``.

        A node matches when ``pattern`` is found (case-insensitively, via
        ``re.search``) in its ``short_desc``, ``long_desc`` or ``code``.
        Results are collected in depth-first pre-order, starting with this
        node.

        Args:
            pattern: Regular expression string.
            max_depth: Maximum depth to descend below this node (``None`` means
                unlimited; ``0`` checks only this node).
            current_depth: Internal recursion counter; callers normally leave
                it at 0.
            valid: If ``None``, return all matching nodes. Otherwise return
                only matching nodes whose ``valid`` equals this value.

        Returns:
            A list of matching ``ICDNode`` objects.
        """
        if max_depth is None:
            max_depth = float('inf')

        results = []
        fields = (self.short_desc, self.long_desc, self.code)
        if any(re.search(pattern, field, re.IGNORECASE) for field in fields):
            if valid is None or self.valid == valid:
                results.append(self)

        if current_depth < max_depth:
            for child in self.children:
                # Recurse into every child regardless of its own validity. A
                # node that fails the filter may still have descendants that
                # pass it, so the filter only decides what gets returned.
                results.extend(
                    child.regex_search(
                        pattern,
                        max_depth=max_depth,
                        current_depth=current_depth + 1,
                        valid=valid,
                    )
                )
        return results

    def __repr__(self):
        """One-line summary; shows the first 4 embedding values, or ``None`` if not embedded."""
        # __init__ always sets self.embedding (to None), so test the value,
        # not just the attribute's presence, to avoid slicing None.
        embedding = getattr(self, 'embedding', None)
        if embedding is not None:
            preview = ", ".join(str(round(x, 5)) for x in embedding[:4])
            embed_str_rep = "[" + preview + ", ...]"
        else:
            embed_str_rep = None
        return f"Code: {self.code}\tLevel: {self.level}\tValid: {self.valid}\tDesc: {self.short_desc}\tVec: {embed_str_rep}"
|
|
|