Upload src/judge_beta_lobe_basic.py with huggingface_hub
Browse files- src/judge_beta_lobe_basic.py +139 -0
src/judge_beta_lobe_basic.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
class BetaLobeBasic:
    """Basic verification lobe (β-Lobe): detects contradictions with anchor facts.

    The constructor stores ``db_interface`` and ``medical_ontology`` but none
    of the methods in this class read them; every check operates on the
    ``db_context`` dict passed in by the caller.

    Each result produced by :meth:`validate_response_basic` is appended to
    ``validation_history``.
    """

    # Signed decimals ("-0.5", ".25") or unsigned integers ("10").
    _NUMBER_PATTERN = re.compile(r'[-+]?\d*\.\d+|\d+')
    # Relative deviation above which a response number counts as "far" from the fact.
    _NUMERIC_TOLERANCE = 0.1
    # Negation markers used by the fallback (non-numerical) contradiction check.
    _NEGATIONS = ("ない", "ではなく", "ではない", "誤り", "間違い")
    # Maximum length of the excerpt returned when no sentence matches a keyword.
    _EXCERPT_FALLBACK_LENGTH = 100

    def __init__(self, db_interface, medical_ontology):
        self.db = db_interface
        self.ontology = medical_ontology
        self.validation_history = []

    @staticmethod
    def _keywords(fact: str) -> list:
        """Return the whitespace-separated tokens of *fact* longer than one character.

        NOTE: text without whitespace (for example a Japanese sentence) yields
        a single keyword equal to the whole fact, so it only matches a response
        that contains the fact verbatim.
        """
        return [word for word in fact.split() if len(word) > 1]

    def _is_mentioned(self, fact: str, response: str) -> bool:
        """Return True when more than half of the fact's keywords occur in *response*.

        Returns False when the fact yields no keywords at all.
        """
        fact_keywords = self._keywords(fact)
        if not fact_keywords:
            return False

        mentioned_count = sum(1 for kw in fact_keywords if kw in response)
        return (mentioned_count / len(fact_keywords)) > 0.5

    def _detect_numerical_contradiction(self, fact: str, response: str) -> bool:
        """Return True when no number in *response* is close to the fact's first number.

        Only the first number found in *fact* is used as the reference value.
        - No number in the fact: nothing to compare, returns False.
        - Number in the fact but none in the response: returns True.
        - Otherwise a contradiction is reported only if every response number
          deviates from the reference by more than 10% (relative).
        """
        fact_numbers = self._NUMBER_PATTERN.findall(fact)
        if not fact_numbers:
            return False
        fact_value = float(fact_numbers[0])

        response_values = [float(v) for v in self._NUMBER_PATTERN.findall(response)]
        if not response_values:
            # The fact carries a number but the response does not.
            return True

        if fact_value == 0:
            # Relative deviation is undefined for a zero reference; treat any
            # non-zero value as deviating instead of dividing by zero.
            return all(value != 0 for value in response_values)

        # Divide by the magnitude so a negative reference value does not flip
        # the sign of the deviation and mask every contradiction.
        scale = abs(fact_value)
        return all(
            abs(value - fact_value) / scale > self._NUMERIC_TOLERANCE
            for value in response_values
        )

    def _detect_contradiction(self, fact: str, response: str, fact_type: str) -> bool:
        """Dispatch to the contradiction check matching *fact_type*.

        ``"numerical"`` uses :meth:`_detect_numerical_contradiction`. Every
        other type falls back to a simple heuristic: the response contains a
        negation marker and also mentions the fact.
        """
        if fact_type == "numerical":
            return self._detect_numerical_contradiction(fact, response)

        # Dedicated logic for other fact types (categorical, causal) is
        # planned for Week 10; until then use the negation heuristic.
        has_negation = any(neg in response for neg in self._NEGATIONS)
        return has_negation and self._is_mentioned(fact, response)

    def _extract_relevant_excerpt(self, fact: str, response: str) -> str:
        """Return the first sentence of *response* containing one of the fact's keywords.

        Sentences are delimited by "。" and only the first three keywords of
        the fact are considered. When no sentence matches, the first 100
        characters of the response are returned, followed by "..." only if
        the response was actually truncated.
        """
        keywords = self._keywords(fact)[:3]
        for sentence in response.split("。"):
            if any(kw in sentence for kw in keywords):
                return sentence.strip() + "。"

        limit = self._EXCERPT_FALLBACK_LENGTH
        if len(response) > limit:
            return response[:limit] + "..."
        return response

    async def check_anchor_facts(self, response_text: str, db_context: dict) -> dict:
        """Check *response_text* against every anchor fact in *db_context*.

        *db_context* maps a coordinate to a tile dict; each tile may carry an
        ``"anchor_facts"`` list whose items are dicts with ``"text"`` and an
        optional ``"type"`` (default ``"causal"``). Facts the response does
        not mention are skipped.

        Returns a dict with ``"contradictions"`` (list of findings),
        ``"contradiction_count"`` and ``"passed"`` (True when none were found).
        """
        contradictions = []

        for coord, tile in db_context.items():
            # A missing or None "anchor_facts" entry is treated as no facts.
            anchor_facts = tile.get("anchor_facts") or []
            for fact in anchor_facts:
                fact_text = fact.get("text", "")
                fact_type = fact.get("type", "causal")

                if not self._is_mentioned(fact_text, response_text):
                    continue

                if self._detect_contradiction(fact_text, response_text, fact_type):
                    contradictions.append({
                        "type": "anchor_fact_contradiction",
                        "coordinate": coord,
                        "fact": fact_text,
                        "fact_type": fact_type,
                        "response_excerpt": self._extract_relevant_excerpt(fact_text, response_text),
                        "severity": "critical",
                    })

        return {
            "contradictions": contradictions,
            "contradiction_count": len(contradictions),
            "passed": len(contradictions) == 0
        }

    async def validate_response_basic(self, alpha_response: dict, db_context: dict) -> dict:
        """Validate an α-Lobe response (Week 9 basic version).

        Only the anchor-fact check is performed. The text is read from
        ``alpha_response["main_response"]``; a missing or None value is
        treated as an empty string.

        Returns the structured validation result, which is also appended to
        ``self.validation_history``. The timestamp is the local time from
        ``datetime.now()`` in ISO format.
        """
        response_text = alpha_response.get("main_response") or ""

        # Step 1: detect contradictions with anchor facts.
        anchor_check_result = await self.check_anchor_facts(response_text, db_context)
        has_contradictions = anchor_check_result["contradiction_count"] > 0

        # Step 2: structure the validation result.
        validation_result = {
            "timestamp": datetime.now().isoformat(),
            "response_text": response_text,
            "checks": {"anchor_facts": anchor_check_result},
            "has_contradictions": has_contradictions,
            "severity": "critical" if has_contradictions else "none",
        }

        self.validation_history.append(validation_result)
        return validation_result
|
| 106 |
+
|
| 107 |
+
# --- 使用例 ---
|
| 108 |
+
async def main():
    """Run the two demonstration scenarios and print each validation result as JSON.

    Each scenario is a banner line, an α-Lobe response dict and a
    ``db_context`` dict; they are processed in order against a single
    ``BetaLobeBasic`` instance.
    """
    import json

    class MockDB:
        # db_context is passed in directly in this example, so no real DB
        # interface is needed.
        pass

    scenarios = [
        # Case 1: labelled as the "contradiction" scenario.
        (
            "--- Case 1: Contradiction Test ---",
            {"main_response": "心筋梗塞は脳の血流が悪くなることで発生します。"},
            {
                (28, 55, 15): {
                    "anchor_facts": [{"text": "心筋梗塞は心臓の冠動脈が詰まることで起こる", "type": "causal"}]
                }
            },
        ),
        # Case 2: labelled as the "no contradiction" scenario.
        (
            "\n--- Case 2: No Contradiction Test ---",
            {"main_response": "心筋梗塞の死亡率は約5%です。"},
            {
                (28, 85, 15): {
                    "anchor_facts": [{"text": "心筋梗塞の急性期死亡率は約5-10%", "type": "numerical"}]
                }
            },
        ),
    ]

    verifier = BetaLobeBasic(MockDB(), None)

    for banner, alpha_response, db_context in scenarios:
        print(banner)
        outcome = await verifier.validate_response_basic(alpha_response, db_context)
        # ensure_ascii=False keeps the Japanese text readable in the output.
        print(json.dumps(outcome, indent=2, ensure_ascii=False))
|
| 136 |
+
|
| 137 |
+
# Script entry point: run the async demo only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())
|