icdbot / icdnode.py
oneilsh's picture
trying huggingface spaces...
c0aa211
raw
history blame
3.91 kB
import re
from sentence_transformers import SentenceTransformer
import globals
import sys
class ICDNode:
    """A single node in a tree of coded entries (code plus descriptions).

    Attributes:
        order: Value passed by the caller; stored unchanged.
        code: The code with any dots removed and, when it is longer than
            three characters, a single dot re-inserted after the third
            character (e.g. ``"A000"`` -> ``"A00.0"``).
        valid: Validity flag passed by the caller; stored unchanged and
            compared with ``==`` by :meth:`regex_search`.
        short_desc: Short description text.
        long_desc: Long description text.
        children: Child nodes, in insertion order unless sorted.
        parent: Parent node, or ``None`` for a node without a parent.
        embedding: Vector produced by :meth:`_add_embeddings`, or ``None``
            until that method has been run.
        level: Depth below the root as computed by :meth:`_set_levels`
            (0 until then).
    """

    def __init__(self, order, code, valid, short_desc, long_desc):
        self.order = order
        # Normalise the code: strip every dot, then put exactly one back
        # after the third character when there is anything to follow it.
        code = code.replace('.', '')
        if len(code) > 3:
            self.code = code[:3] + '.' + code[3:]
        else:
            self.code = code
        self.valid = valid
        self.short_desc = short_desc
        self.long_desc = long_desc
        self.children = []
        self.parent = None
        self.embedding = None
        self.level = 0

    def _set_levels(self):
        """Recompute ``level`` for this node and every descendant.

        A node with a parent gets the parent's level plus one; a node
        without a parent keeps its current level. Children are processed
        after their parent, so each child is one level deeper.
        """
        if self.parent:
            self.level = self.parent.level + 1
        for child in self.children:
            child._set_levels()

    def get_parent(self):
        """Return the parent node of this node (``None`` if it has none)."""
        return self.parent

    def get_children(self):
        """Return the list of immediate children of this node."""
        return self.children

    def get_descendants(self):
        """Return all nodes below this one as a list, in depth-first order.

        Each child is followed immediately by its own descendants.
        """
        descendants = []
        for child in self.children:
            descendants.append(child)
            descendants.extend(child.get_descendants())
        return descendants

    def get_ancestors(self):
        """Return the chain of parents, nearest first, up to the root."""
        ancestors = []
        node = self.parent
        while node is not None:
            ancestors.append(node)
            node = node.parent
        return ancestors

    def _add_embeddings(self, embedding_model, prefix=""):
        """Compute and store embeddings for this node and all descendants.

        Args:
            embedding_model: Object exposing ``encode(list_of_str)`` that
                returns one vector per input string.
            prefix: Text prepended to ``short_desc`` before encoding.

        Side effects:
            Increments ``globals.num_embedded`` and writes a progress line
            (percentage of ``globals.total_nodes``) to stderr per node.
        """
        self.embedding = embedding_model.encode([prefix + self.short_desc])[0]
        globals.num_embedded += 1
        percent_complete = round(100 * (globals.num_embedded / globals.total_nodes), 2)
        sys.stderr.write(f"({percent_complete}%) " + str(self) + "\n")
        for child in self.children:
            child._add_embeddings(embedding_model, prefix)

    def _sort_children(self):
        """Sort the immediate children in place by their code."""
        self.children.sort(key=lambda x: x.code)

    def _add_child(self, child_node):
        """Append ``child_node`` to the children and set its parent to self."""
        self.children.append(child_node)
        child_node.parent = self

    def regex_search(self, pattern, max_depth=None, current_depth=0, valid=None):
        """Search this node and its descendants with a regular expression.

        A node matches when ``pattern`` is found (case-insensitively) in its
        ``short_desc``, ``long_desc`` or ``code``.

        Args:
            pattern: Regular expression string passed to ``re.search``.
            max_depth: How many levels below the starting node to search;
                ``None`` means unlimited, ``0`` means only this node.
            current_depth: Depth of this node relative to the starting
                node; used by the recursion and normally left at 0.
            valid: If ``None``, return every matching node. Otherwise
                return only matching nodes whose ``valid`` attribute
                equals this value.

        Returns:
            A list of matching nodes in depth-first order.
        """
        if max_depth is None:
            max_depth = float('inf')

        results = []
        searched_fields = (self.short_desc, self.long_desc, self.code)
        is_match = any(
            re.search(pattern, text, re.IGNORECASE) for text in searched_fields
        )
        if is_match and (valid is None or self.valid == valid):
            results.append(self)

        if current_depth < max_depth:
            # The validity filter decides what is *returned*, not what is
            # *traversed*: a node with the other validity can still have
            # descendants that satisfy the filter, so always descend.
            for child in self.children:
                results.extend(
                    child.regex_search(
                        pattern,
                        max_depth=max_depth,
                        current_depth=current_depth + 1,
                        valid=valid,
                    )
                )
        return results

    def __repr__(self):
        """Return a tab-separated summary including the first 4 vector values."""
        # The attribute always exists (set to None in __init__), so test its
        # value rather than its presence to avoid slicing None.
        if self.embedding is not None:
            embed_str_rep = "[" + ", ".join([str(round(x, 5)) for x in self.embedding[:4]]) + ", ...]"
        else:
            embed_str_rep = None
        return f"Code: {self.code}\tLevel: {self.level}\tValid: {self.valid}\tDesc: {self.short_desc}\tVec: {embed_str_rep}"