File size: 3,909 Bytes
c0aa211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import re
from sentence_transformers import SentenceTransformer
import globals
import sys

class ICDNode:
    """A node in a tree of ICD codes.

    Each node stores the fields it was constructed with (``order``, ``code``,
    ``valid``, ``short_desc``, ``long_desc``) plus its tree links
    (``parent``, ``children``), its depth in the tree (``level``) and an
    optional ``embedding`` vector, which stays ``None`` until
    ``_add_embeddings`` has been run.
    """

    def __init__(self, order, code, valid, short_desc, long_desc):
        self.order = order

        # Normalise the code: strip any dots, then put a single dot after the
        # third character when the code is longer than three characters.
        code = code.replace('.', '')
        if len(code) > 3:
            self.code = code[:3] + '.' + code[3:]
        else:
            self.code = code

        self.valid = valid
        self.short_desc = short_desc
        self.long_desc = long_desc
        self.children = []
        self.parent = None
        self.embedding = None
        self.level = 0

    def _set_levels(self):
        """Recompute ``level`` for this node and every descendant.

        A node with a parent gets the parent's level plus one; a node without
        a parent keeps its current level (0 unless changed elsewhere).
        """
        if self.parent is not None:
            self.level = self.parent.level + 1
        for child in self.children:
            child._set_levels()

    def get_parent(self):
        """Return the parent node of this node (``None`` if it has none)."""
        return self.parent

    def get_children(self):
        """Return the list of immediate children of this node."""
        return self.children

    def get_descendants(self):
        """Return all nodes below this one as a list, in depth-first (pre-order) order."""
        descendants = []
        for child in self.children:
            descendants.append(child)
            descendants.extend(child.get_descendants())
        return descendants

    def get_ancestors(self):
        """Return this node's ancestors as a list, nearest parent first."""
        ancestors = []
        node = self.parent
        while node is not None:
            ancestors.append(node)
            node = node.parent
        return ancestors

    def _add_embeddings(self, embedding_model, prefix = ""):
        """Embed ``prefix + short_desc`` for this node and all its descendants.

        ``embedding_model`` must provide ``encode(list_of_str)`` returning one
        vector per input string. Progress is tracked through the counters
        ``num_embedded`` and ``total_nodes`` of the imported ``globals``
        module and reported on stderr after each node.
        """
        self.embedding = embedding_model.encode([prefix + self.short_desc])[0]
        globals.num_embedded += 1
        percent_complete = round(100 * (globals.num_embedded / globals.total_nodes), 2)
        sys.stderr.write(f"({percent_complete}%) " + str(self) + "\n")
        for child in self.children:
            child._add_embeddings(embedding_model, prefix)

    def _sort_children(self):
        """Sort the immediate children in place by their code."""
        self.children.sort(key=lambda x: x.code)

    def _add_child(self, child_node):
        """Append ``child_node`` to this node's children and set its parent link."""
        self.children.append(child_node)
        child_node.parent = self

    def regex_search(self, pattern, max_depth=None, current_depth=0, valid = None):
        """Return the nodes in this subtree matching a regex, in depth-first order.

        A node matches when ``pattern`` (searched case-insensitively) is found
        in its ``short_desc``, ``long_desc`` or ``code``.

        pattern: regular expression string.
        max_depth: how many levels below this node to search; ``None`` means
            no limit and ``0`` examines only this node.
        current_depth: depth already consumed; callers normally leave it at 0.
        valid: if ``True`` only nodes with ``valid == True`` are returned, if
            ``False`` only nodes with ``valid == False``; ``None`` returns all
            matching nodes.

        The ``valid`` filter only decides which nodes are *returned*; the
        search always descends through every child, so matching nodes located
        beneath a node with a different ``valid`` flag are still found.
        """
        if max_depth is None:
            max_depth = float('inf')

        # Compile once for the whole traversal instead of per node and field.
        regex = re.compile(pattern, re.IGNORECASE)

        results = []
        # Explicit stack of (node, depth); children are pushed in reverse so
        # they are popped, and therefore reported, in their stored order.
        stack = [(self, current_depth)]
        while stack:
            node, depth = stack.pop()
            if valid is None or node.valid == valid:
                if (regex.search(node.short_desc)
                        or regex.search(node.long_desc)
                        or regex.search(node.code)):
                    results.append(node)
            if depth < max_depth:
                stack.extend((child, depth + 1) for child in reversed(node.children))

        return results

    def __repr__(self):
        # The embedding is None until _add_embeddings has run, so check the
        # value rather than just whether the attribute exists.
        embedding = getattr(self, 'embedding', None)
        if embedding is None:
            embed_str_rep = None
        else:
            # Show only the first four components, rounded to 5 decimals.
            embed_str_rep = "[" + ", ".join([str(round(x, 5)) for x in embedding[:4]]) + ", ...]"
        return f"Code: {self.code}\tLevel: {self.level}\tValid: {self.valid}\tDesc: {self.short_desc}\tVec: {embed_str_rep}"