import re
import sys

from sentence_transformers import SentenceTransformer

# NOTE: project-local module; it shadows the builtin ``globals()`` in this file.
import globals


class ICDNode:
    """A node in an ICD code hierarchy.

    Each node holds a normalized code, a validity flag, short and long
    descriptions, links to its parent and children, an optional embedding
    vector, and its depth (``level``) in the tree.
    """

    def __init__(self, order, code, valid, short_desc, long_desc):
        """Create a detached node (no parent, no children, no embedding).

        Args:
            order: Stored as-is on the node (presumably the source ordering
                index — confirm against the loader).
            code: ICD code, with or without dots. Dots are stripped and a
                single '.' is re-inserted after the third character when
                the code is longer than three characters (e.g. "A001" ->
                "A00.1").
            valid: Validity flag, compared by ``regex_search``'s ``valid``
                filter.
            short_desc: Short description; used for embedding and search.
            long_desc: Long description; used for search.
        """
        self.order = order
        code = code.replace('.', '')
        if len(code) > 3:
            self.code = code[:3] + '.' + code[3:]
        else:
            self.code = code
        self.valid = valid
        self.short_desc = short_desc
        self.long_desc = long_desc
        self.children = []
        self.parent = None
        self.embedding = None  # filled in by _add_embeddings
        self.level = 0  # depth in the tree; recomputed by _set_levels

    def _set_levels(self):
        """Recompute ``level`` for this node and all of its descendants.

        A node with a parent gets its parent's level plus one; a node
        without a parent keeps its current level (0 unless changed).
        """
        if self.parent:
            self.level = self.parent.level + 1
        for child in self.children:
            child._set_levels()

    def get_parent(self):
        """Return the parent node of this node, or None for a root."""
        return self.parent

    def get_children(self):
        """Return the list of immediate children (the internal list, not a copy)."""
        return self.children

    def get_descendants(self):
        """Return all nodes below this one as a list, in depth-first pre-order."""
        descendants = []
        for child in self.children:
            descendants.append(child)
            descendants.extend(child.get_descendants())
        return descendants

    def get_ancestors(self):
        """Return the ancestors of this node, nearest (parent) first, root last."""
        ancestors = []
        parent = self.parent
        while parent:
            ancestors.append(parent)
            parent = parent.parent
        return ancestors

    def _add_embeddings(self, embedding_model, prefix=""):
        """Embed this node and all its descendants, reporting progress on stderr.

        Each node's ``embedding`` is set to the first element returned by
        ``embedding_model.encode([prefix + short_desc])``. The shared
        counter ``globals.num_embedded`` is incremented once per node, and
        a percentage line computed from ``globals.total_nodes`` is written
        to stderr.

        NOTE(review): assumes ``globals.num_embedded`` and a non-zero
        ``globals.total_nodes`` are initialized by the caller.
        """
        self.embedding = embedding_model.encode([prefix + self.short_desc])[0]
        globals.num_embedded += 1
        percent_complete = round(100 * (globals.num_embedded / globals.total_nodes), 2)
        sys.stderr.write(f"({percent_complete}%) " + str(self) + "\n")
        for child in self.children:
            child._add_embeddings(embedding_model, prefix)

    def _sort_children(self):
        """Sort the immediate children in place by their code string."""
        self.children.sort(key=lambda x: x.code)

    def _add_child(self, child_node):
        """Append ``child_node`` to this node's children and set its parent link."""
        self.children.append(child_node)
        child_node.parent = self

    def regex_search(self, pattern, max_depth=None, current_depth=0, valid=None):
        """Return nodes whose short_desc, long_desc, or code matches ``pattern``.

        Matching is case-insensitive (``re.search`` with ``re.IGNORECASE``).
        Results are listed in depth-first pre-order, starting with this node.

        Args:
            pattern: Regular expression string.
            max_depth: Maximum number of levels below this node to search;
                None means unlimited.
            current_depth: Depth of this node relative to the search start
                (used internally by the recursion).
            valid: If True, return only valid nodes; if False, only invalid
                nodes; if None, return all matching nodes. The filter only
                decides what is returned. The whole subtree is still searched,
                so valid descendants of invalid nodes are found (and vice versa).
        """
        if max_depth is None:
            max_depth = float('inf')
        results = []
        fields = (self.short_desc, self.long_desc, self.code)
        if any(re.search(pattern, field, re.IGNORECASE) for field in fields):
            if valid is None or self.valid == valid:
                results.append(self)
        if current_depth < max_depth:
            # Recurse into every child regardless of its validity; filtering
            # children here would hide matching nodes deeper in the subtree.
            for child in self.children:
                results.extend(
                    child.regex_search(
                        pattern,
                        max_depth=max_depth,
                        current_depth=current_depth + 1,
                        valid=valid,
                    )
                )
        return results

    def __repr__(self):
        """Return a tab-separated summary including the first 4 embedding values.

        ``Vec`` is None when the node has not been embedded yet.
        """
        embedding = getattr(self, 'embedding', None)
        if embedding is not None:
            embed_str_rep = "[" + ", ".join([str(round(x, 5)) for x in embedding[:4]]) + ", ...]"
        else:
            embed_str_rep = None
        return f"Code: {self.code}\tLevel: {self.level}\tValid: {self.valid}\tDesc: {self.short_desc}\tVec: {embed_str_rep}"