kofdai commited on
Commit
deb3a5d
·
verified ·
1 Parent(s): 2c50b9b

Upload src/judge_correction_flow.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/judge_correction_flow.py +115 -0
src/judge_correction_flow.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+
3
+ class JudgeCorrectionFlow:
4
+ """
5
+ β-Lobeの検証結果に基づき、回答を「承認」「自動修正」「再生成」の
6
+ いずれのアクションに振り分け、実行を制御します。
7
+ """
8
+
9
+ def __init__(self, alpha_lobe, beta_lobe):
10
+ self.alpha_lobe = alpha_lobe
11
+ self.beta_lobe = beta_lobe
12
+
13
+ def _summarize_and_decide_action(self, validation_result: dict) -> str:
14
+ """検証結果の深刻度に基づき、次のアクションを決定する。"""
15
+ severity = validation_result.get("severity", "none")
16
+
17
+ if severity == "critical":
18
+ return "regenerate"
19
+ if severity == "moderate":
20
+ # 中程度の問題が1つでもあれば再生成を試みる(より安全な方針)
21
+ return "regenerate"
22
+
23
+ # 軽微な問題や問題なしの場合は承認
24
+ return "approve"
25
+
26
+ def _construct_regeneration_feedback(self, validation_result: dict) -> str:
27
+ """再生成を指示するためのフィードバック文を構築する。"""
28
+ feedback_parts = []
29
+ for check_name, check_result in validation_result.get("checks", {}).items():
30
+ issues = check_result.get("contradictions", []) + check_result.get("logical_errors", []) + check_result.get("issues", [])
31
+ for issue in issues[:2]: # 各カテゴリから最大2件
32
+ issue_type = issue.get("type", "issue")
33
+ detail = issue.get("fact", issue.get("message", "詳細不明"))
34
+ feedback_parts.append(f"✗ {issue_type}: {detail}")
35
+ return "\n".join(feedback_parts)
36
+
37
+ async def _auto_correct_response(self, original_response: str, validation_result: dict) -> str:
38
+ """軽微な問題を自動修正する。"""
39
+ corrected_response = original_response
40
+ recommendations = validation_result.get("recommendations", [])
41
+
42
+ for rec in recommendations:
43
+ if rec['type'] == 'fact_correction':
44
+ if rec['current_statement'] in corrected_response:
45
+ corrected_response = corrected_response.replace(rec['current_statement'], rec['correct_statement'])
46
+ return corrected_response
47
+
48
+ async def process_and_correct(self, question: str, db_context: dict, session_context=None, web_results=None, max_regenerations: int = 1, domain_id: str = "medical"):
49
+ """
50
+ 質問を処理し、生成、検証、修正/再生成の完全なフローを実行する。
51
+ ドメイン対応版。
52
+ """
53
+ regeneration_count = 0
54
+ print(f" -> JudgeCorrectionFlow開始 (domain={domain_id})")
55
+
56
+ # α-Lobeで初回回答生成(ドメイン対応)
57
+ alpha_response = await self.alpha_lobe.generate_response(
58
+ question, db_context, session_context, domain_id=domain_id
59
+ )
60
+
61
+ while regeneration_count <= max_regenerations:
62
+ # β-Lobeで検証(ドメイン対応)
63
+ validation = await self.beta_lobe.validate_response(
64
+ question, alpha_response, db_context, web_results, session_context, domain=domain_id
65
+ )
66
+
67
+ # アクションを決定
68
+ action = self._summarize_and_decide_action(validation)
69
+ print(f" -> 検証結果: severity={validation.get('severity')}, action={action}")
70
+
71
+ if action == "approve":
72
+ return {
73
+ "status": "approved",
74
+ "response": alpha_response["main_response"],
75
+ "structured": alpha_response.get("structured", {}),
76
+ "confidence": alpha_response.get("confidence", 0.0),
77
+ "validation": validation,
78
+ "domain": domain_id
79
+ }
80
+
81
+ if action == "auto_correct":
82
+ corrected_text = await self._auto_correct_response(alpha_response["main_response"], validation)
83
+ alpha_response["main_response"] = corrected_text
84
+ second_validation = await self.beta_lobe.validate_response(
85
+ question, alpha_response, db_context, web_results, session_context, domain=domain_id
86
+ )
87
+ return {
88
+ "status": "corrected",
89
+ "response": corrected_text,
90
+ "original_validation": validation,
91
+ "final_validation": second_validation,
92
+ "domain": domain_id
93
+ }
94
+
95
+ if action == "regenerate":
96
+ if regeneration_count < max_regenerations:
97
+ regeneration_count += 1
98
+ print(f" -> 再生成試行 {regeneration_count}/{max_regenerations}")
99
+ feedback = self._construct_regeneration_feedback(validation)
100
+ regeneration_prompt = f"前回の回答に以下の問題がありました:\n{feedback}\n\n元の質問: {question}\n\nこれらの点を修正して、再度回答してください。"
101
+
102
+ # α-Lobeにフィードバックを与えて再生成(ドメイン対応)
103
+ alpha_response = await self.alpha_lobe.generate_response(
104
+ regeneration_prompt, db_context, session_context, domain_id=domain_id
105
+ )
106
+ continue
107
+ else:
108
+ return {
109
+ "status": "unable_to_answer",
110
+ "reason": "再生成の上限に達しましたが、問題が解決しませんでした。",
111
+ "final_validation": validation,
112
+ "domain": domain_id
113
+ }
114
+
115
+ return {"status": "error", "message": "予期せぬエラーが発生しました。", "domain": domain_id}