kofdai commited on
Commit
a95de4f
·
verified ·
1 Parent(s): ee76e96

Upload src/judge_beta_lobe_basic.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/judge_beta_lobe_basic.py +139 -0
src/judge_beta_lobe_basic.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import re
3
+
4
+ class BetaLobeBasic:
5
+ """
6
+ 検証院(β-Lobe)の基本機能:Anchor事実との矛盾検出を実装。
7
+ """
8
+ def __init__(self, db_interface, medical_ontology):
9
+ self.db = db_interface
10
+ self.ontology = medical_ontology
11
+ self.validation_history = []
12
+
13
+ def _is_mentioned(self, fact: str, response: str) -> bool:
14
+ """事実に関連するキーワードが回答に含まれているか簡易的に判定"""
15
+ # 事実から主要な名詞を抽出(簡易的な実装)
16
+ fact_keywords = [word for word in fact.split() if len(word) > 1]
17
+ if not fact_keywords: return False
18
+
19
+ mentioned_count = sum(1 for kw in fact_keywords if kw in response)
20
+ return (mentioned_count / len(fact_keywords)) > 0.5
21
+
22
+ def _detect_numerical_contradiction(self, fact: str, response: str) -> bool:
23
+ """数値の矛盾を検出"""
24
+ fact_numbers = re.findall(r'[-+]?\d*\.\d+|\d+', fact)
25
+ if not fact_numbers: return False
26
+ fact_value = float(fact_numbers[0])
27
+
28
+ response_numbers = re.findall(r'[-+]?\d*\.\d+|\d+', response)
29
+ if not response_numbers: return True # 事実には数値があるが、回答にはない
30
+
31
+ # 回答内の最も近い数値が、事実の数値と10%以上乖離していれば矛盾
32
+ is_far = all(abs(float(res_val) - fact_value) / fact_value > 0.1 for res_val in response_numbers)
33
+ return is_far
34
+
35
+ def _detect_contradiction(self, fact: str, response: str, fact_type: str) -> bool:
36
+ """事実のタイプに応じて矛盾検出ロジックを振り分け"""
37
+ if fact_type == "numerical":
38
+ return self._detect_numerical_contradiction(fact, response)
39
+ # 他のfact_type(categorical, causal)の実装はWeek 10
40
+ else:
41
+ # デフォルト:否定語の存在で簡易的に判定
42
+ negations = ["ない", "ではなく", "ではない", "誤り", "間違い"]
43
+ if any(neg in response for neg in negations) and self._is_mentioned(fact, response):
44
+ return True
45
+ return False
46
+
47
+ def _extract_relevant_excerpt(self, fact: str, response: str) -> str:
48
+ """事実に関連する回答の抜粋を抽出"""
49
+ keywords = [word for word in fact.split() if len(word) > 1][:3]
50
+ sentences = response.split("。")
51
+ for sentence in sentences:
52
+ if any(kw in sentence for kw in keywords):
53
+ return sentence.strip() + "。"
54
+ return response[:100] + "..."
55
+
56
+ async def check_anchor_facts(self, response_text: str, db_context: dict) -> dict:
57
+ """Anchor事実との矛盾を検出する"""
58
+ contradictions = []
59
+
60
+ for coord, tile in db_context.items():
61
+ anchor_facts = tile.get("anchor_facts", [])
62
+ for fact in anchor_facts:
63
+ fact_text = fact.get("text", "")
64
+ fact_type = fact.get("type", "causal")
65
+
66
+ if not self._is_mentioned(fact_text, response_text):
67
+ continue
68
+
69
+ if self._detect_contradiction(fact_text, response_text, fact_type):
70
+ contradictions.append({
71
+ "type": "anchor_fact_contradiction",
72
+ "coordinate": coord,
73
+ "fact": fact_text,
74
+ "fact_type": fact_type,
75
+ "response_excerpt": self._extract_relevant_excerpt(fact_text, response_text),
76
+ "severity": "critical",
77
+ })
78
+
79
+ return {
80
+ "contradictions": contradictions,
81
+ "contradiction_count": len(contradictions),
82
+ "passed": len(contradictions) == 0
83
+ }
84
+
85
+ async def validate_response_basic(self, alpha_response: dict, db_context: dict) -> dict:
86
+ """
87
+ α-Lobeの回答を検証する(Week 9の基本機能版)。
88
+ Anchor事実チェックのみを行う。
89
+ """
90
+ response_text = alpha_response.get("main_response", "")
91
+
92
+ # ステップ1: Anchor事実との矛盾検出
93
+ anchor_check_result = await self.check_anchor_facts(response_text, db_context)
94
+
95
+ # ステップ2: 検証結果を構造化
96
+ validation_result = {
97
+ "timestamp": datetime.now().isoformat(),
98
+ "response_text": response_text,
99
+ "checks": {"anchor_facts": anchor_check_result},
100
+ "has_contradictions": anchor_check_result["contradiction_count"] > 0,
101
+ "severity": "critical" if anchor_check_result["contradiction_count"] > 0 else "none",
102
+ }
103
+
104
+ self.validation_history.append(validation_result)
105
+ return validation_result
106
+
107
+ # --- 使用例 ---
108
+ async def main():
109
+ class MockDB:
110
+ pass # この例ではdb_contextを直接渡すため、DBインターフェースは不要
111
+
112
+ beta_lobe = BetaLobeBasic(MockDB(), None)
113
+
114
+ # --- ケース1: 矛盾あり ---
115
+ print("--- Case 1: Contradiction Test ---")
116
+ alpha_res_1 = {"main_response": "心筋梗塞は脳の血流が悪くなることで発生します。"}
117
+ db_ctx_1 = {
118
+ (28, 55, 15): {
119
+ "anchor_facts": [{"text": "心筋梗塞は心臓の冠動脈が詰まることで起こる", "type": "causal"}]
120
+ }
121
+ }
122
+ validation_1 = await beta_lobe.validate_response_basic(alpha_res_1, db_ctx_1)
123
+ import json
124
+ print(json.dumps(validation_1, indent=2, ensure_ascii=False))
125
+
126
+ # --- ケース2: 矛盾なし ---
127
+ print("\n--- Case 2: No Contradiction Test ---")
128
+ alpha_res_2 = {"main_response": "心筋梗塞の死亡率は約5%です。"}
129
+ db_ctx_2 = {
130
+ (28, 85, 15): {
131
+ "anchor_facts": [{"text": "心筋梗塞の急性期死亡率は約5-10%", "type": "numerical"}]
132
+ }
133
+ }
134
+ validation_2 = await beta_lobe.validate_response_basic(alpha_res_2, db_ctx_2)
135
+ print(json.dumps(validation_2, indent=2, ensure_ascii=False))
136
+
137
+ if __name__ == "__main__":
138
+ import asyncio
139
+ asyncio.run(main())