Upload src/judge_correction_flow.py with huggingface_hub
Browse files- src/judge_correction_flow.py +115 -0
src/judge_correction_flow.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
|
| 3 |
+
class JudgeCorrectionFlow:
    """Decide what to do with an α-Lobe answer based on β-Lobe validation.

    The flow generates an answer with ``alpha_lobe``, validates it with
    ``beta_lobe`` and then dispatches on one of three actions: ``"approve"``,
    ``"auto_correct"`` or ``"regenerate"``.

    NOTE(review): ``_summarize_and_decide_action`` currently only ever returns
    ``"approve"`` or ``"regenerate"``, so the ``"auto_correct"`` branch of
    ``process_and_correct`` is not reachable with the present policy. The
    branch is kept so the policy can start emitting ``"auto_correct"``
    without further changes here.
    """

    def __init__(self, alpha_lobe, beta_lobe):
        """Store the generator (``alpha_lobe``) and validator (``beta_lobe``).

        Args:
            alpha_lobe: Object whose awaitable ``generate_response`` is
                called as ``(question, db_context, session_context,
                domain_id=...)``.
            beta_lobe: Object whose awaitable ``validate_response`` is
                called as ``(question, response, db_context, web_results,
                session_context, domain=...)``.
        """
        self.alpha_lobe = alpha_lobe
        self.beta_lobe = beta_lobe

    def _summarize_and_decide_action(self, validation_result: dict) -> str:
        """Pick the next action from the severity of the validation result.

        Args:
            validation_result: Validation output; only its ``"severity"``
                key is read (defaults to ``"none"`` when absent).

        Returns:
            ``"regenerate"`` for ``"critical"`` or ``"moderate"`` severity,
            otherwise ``"approve"``.
        """
        severity = validation_result.get("severity", "none")

        # Even a single moderate problem triggers regeneration (the safer
        # policy); critical problems obviously do as well.
        if severity in ("critical", "moderate"):
            return "regenerate"

        # Minor problems, or no problems at all, are approved.
        return "approve"

    def _construct_regeneration_feedback(self, validation_result: dict) -> str:
        """Build the feedback text that is sent along with a regeneration request.

        For every entry under ``validation_result["checks"]`` the lists found
        under ``"contradictions"``, ``"logical_errors"`` and ``"issues"`` are
        combined (in that order) and the first two items are reported.

        Args:
            validation_result: Validation output containing an optional
                ``"checks"`` mapping of check name to check result.

        Returns:
            One ``"✗ <type>: <detail>"`` line per reported issue, joined by
            newlines; an empty string when there is nothing to report.
        """
        feedback_parts = []
        checks = validation_result.get("checks") or {}

        for check_result in checks.values():
            if not isinstance(check_result, dict):
                continue

            # Missing or None entries are treated as "no issues" instead of
            # raising on list concatenation.
            issues = []
            for key in ("contradictions", "logical_errors", "issues"):
                value = check_result.get(key)
                if isinstance(value, (list, tuple)):
                    issues.extend(value)

            # At most two issues per check, taken from the combined list.
            for issue in issues[:2]:
                if isinstance(issue, dict):
                    issue_type = issue.get("type", "issue")
                    detail = issue.get("fact", issue.get("message", "詳細不明"))
                else:
                    # Validators that report plain strings are still surfaced.
                    issue_type = "issue"
                    detail = str(issue)
                feedback_parts.append(f"✗ {issue_type}: {detail}")

        return "\n".join(feedback_parts)

    async def _auto_correct_response(self, original_response: str, validation_result: dict) -> str:
        """Apply ``fact_correction`` recommendations to the response text.

        Args:
            original_response: Text to correct.
            validation_result: Validation output; its optional
                ``"recommendations"`` list is read. Entries with
                ``type == "fact_correction"`` replace every occurrence of
                ``current_statement`` with ``correct_statement``.

        Returns:
            The corrected text. Malformed recommendations (not a dict,
            missing or non-string statements) and empty ``current_statement``
            values are skipped.
        """
        corrected_response = original_response

        for rec in validation_result.get("recommendations") or []:
            if not isinstance(rec, dict) or rec.get("type") != "fact_correction":
                continue

            current = rec.get("current_statement")
            replacement = rec.get("correct_statement")
            if not isinstance(current, str) or not isinstance(replacement, str):
                continue

            # An empty needle matches everywhere: str.replace("", x) would
            # insert x between every character of the response.
            if not current:
                continue

            corrected_response = corrected_response.replace(current, replacement)

        return corrected_response

    async def process_and_correct(self, question: str, db_context: dict, session_context=None, web_results=None, max_regenerations: int = 1, domain_id: str = "medical"):
        """Run the full generate -> validate -> correct/regenerate flow.

        Args:
            question: The user's question.
            db_context: Context passed through to both lobes.
            session_context: Optional session context passed to both lobes.
            web_results: Optional web results passed to the validator.
            max_regenerations: Maximum number of regeneration attempts.
            domain_id: Domain identifier forwarded to both lobes and echoed
                in the result.

        Returns:
            A dict whose ``"status"`` is one of:

            * ``"approved"`` - with ``response``, ``structured``,
              ``confidence`` and ``validation``;
            * ``"corrected"`` - with ``response``, ``original_validation``
              and ``final_validation``;
            * ``"unable_to_answer"`` - regeneration limit reached, with
              ``reason`` and ``final_validation``;
            * ``"error"`` - the loop finished without a decision (e.g. a
              negative ``max_regenerations``).

            Every result also carries ``"domain"``.
        """
        regeneration_count = 0
        print(f" -> JudgeCorrectionFlow開始 (domain={domain_id})")

        # Initial answer from the α-Lobe (domain aware).
        alpha_response = await self.alpha_lobe.generate_response(
            question, db_context, session_context, domain_id=domain_id
        )

        while regeneration_count <= max_regenerations:
            # Validate with the β-Lobe (domain aware). The original question
            # is always used here, even after a regeneration prompt.
            validation = await self.beta_lobe.validate_response(
                question, alpha_response, db_context, web_results, session_context, domain=domain_id
            )

            action = self._summarize_and_decide_action(validation)
            print(f" -> 検証結果: severity={validation.get('severity')}, action={action}")

            if action == "approve":
                return {
                    "status": "approved",
                    "response": alpha_response["main_response"],
                    "structured": alpha_response.get("structured", {}),
                    "confidence": alpha_response.get("confidence", 0.0),
                    "validation": validation,
                    "domain": domain_id,
                }

            if action == "auto_correct":
                corrected_text = await self._auto_correct_response(alpha_response["main_response"], validation)
                alpha_response["main_response"] = corrected_text
                # Re-validate so the caller can see whether the fix helped.
                second_validation = await self.beta_lobe.validate_response(
                    question, alpha_response, db_context, web_results, session_context, domain=domain_id
                )
                return {
                    "status": "corrected",
                    "response": corrected_text,
                    "original_validation": validation,
                    "final_validation": second_validation,
                    "domain": domain_id,
                }

            if action == "regenerate":
                if regeneration_count >= max_regenerations:
                    return {
                        "status": "unable_to_answer",
                        "reason": "再生成の上限に達しましたが、問題が解決しませんでした。",
                        "final_validation": validation,
                        "domain": domain_id,
                    }

                regeneration_count += 1
                print(f" -> 再生成試行 {regeneration_count}/{max_regenerations}")
                feedback = self._construct_regeneration_feedback(validation)
                regeneration_prompt = f"前回の回答に以下の問題がありました:\n{feedback}\n\n元の質問: {question}\n\nこれらの点を修正して、再度回答してください。"

                # Ask the α-Lobe to regenerate with the feedback (domain aware).
                alpha_response = await self.alpha_lobe.generate_response(
                    regeneration_prompt, db_context, session_context, domain_id=domain_id
                )

        return {"status": "error", "message": "予期せぬエラーが発生しました。", "domain": domain_id}
|