kofdai commited on
Commit
aa22d42
·
verified ·
1 Parent(s): deb3a5d

Upload src/inference_engine_unified.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/inference_engine_unified.py +177 -0
src/inference_engine_unified.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from typing import Dict, Any
3
+
4
+ # --- これまでに実装した全コンポーネントをインポート ---
5
+ from domain_manager import DomainManager
6
+ from backend.iath_db_interface import IathDBInterface
7
+ from backend.deepseek_local_client import DeepSeekLocalClient, DeepSeekConfig
8
+ from layer1_spatial_encoding import SpatialEncodingEngine
9
+ from layer2_episodic_binding import EpisodePalace
10
+ from judge_alpha_lobe import AlpheLobe
11
+ from judge_beta_lobe_advanced import BetaLobeAdvanced
12
+ from judge_correction_flow import JudgeCorrectionFlow
13
+ from layer5_state_management import ExternalState, LayerResetManager
14
+ from hallucination_detector import calculate_hallucination_risk_score
15
+
16
+ # --- モックオブジェクト(テスト用) ---
17
+ from mock_objects import MockOntology
18
+
19
+ class InferenceEngine:
20
+ """
21
+ Ilm-Athensの全推論レイヤーを統合し、単一のインターフェースを提供するクラス。
22
+ """
23
+ def __init__(self, deepseek_config: DeepSeekConfig, db_path: str):
24
+ """
25
+ 推論エンジンの初期化。すべてのコアコンポーネントをロードします。
26
+
27
+ Args:
28
+ deepseek_config (DeepSeekConfig): DeepSeekローカルクライアントの設定。
29
+ db_path (str): .iath データベースファイルのパス。
30
+ """
31
+ print("--- Ilm-Athens Unified Inference Engine Initializing... ---")
32
+ # 外部クライアントの初期化
33
+ self.llm_client = DeepSeekLocalClient(config=deepseek_config)
34
+ self.db_interface = IathDBInterface(db_file_path=db_path)
35
+
36
+ # マネージャーと汎用コンポーネントの初期化
37
+ self.domain_manager = DomainManager()
38
+ self.ontology = MockOntology() # オントロジーはまだモックを使用
39
+
40
+ # セッション管理用のストレージ
41
+ self.sessions: Dict[str, Dict[str, Any]] = {}
42
+
43
+ # DBのロード
44
+ if not self.db_interface.load_db():
45
+ print(f"警告: DBファイル '{db_path}' のロードに失敗しました。DBコンテキストなしで動作します。")
46
+
47
+ print("--- Initialization Complete. ---")
48
+
49
+ def _get_or_create_session(self, session_id: str) -> Dict[str, Any]:
50
+ """セッションIDに対応するオブジェクトを取得または新規作成する。"""
51
+ if session_id not in self.sessions:
52
+ self.sessions[session_id] = {
53
+ "episode_palace": EpisodePalace(session_id),
54
+ "external_state": ExternalState(),
55
+ "reset_manager": LayerResetManager(self.llm_client)
56
+ }
57
+ return self.sessions[session_id]
58
+
59
+ async def process_question(
60
+ self,
61
+ question: str,
62
+ session_id: str,
63
+ domain_id: str = "medical"
64
+ ) -> Dict[str, Any]:
65
+ """
66
+ 単一の質問を受け取り、推論から検証までの完全なパイプラインを実行します。
67
+
68
+ Args:
69
+ question (str): ユーザーからの質問。
70
+ session_id (str): 現在の会話セッションを識別するID。
71
+ domain_id (str): 使用する知識ドメイン (例: "medical", "legal")。
72
+
73
+ Returns:
74
+ dict: 最終的な処理結果を含む辞書。
75
+ """
76
+ print(f"\n--- Processing question in session '{session_id}' for domain '{domain_id}' ---")
77
+
78
+ # 1. セッションオブジェクトを取得
79
+ session = self._get_or_create_session(session_id)
80
+ episode_palace: EpisodePalace = session["episode_palace"]
81
+ external_state: ExternalState = session["external_state"]
82
+ reset_manager: LayerResetManager = session["reset_manager"]
83
+
84
+ # 2. ドメインスキーマと、それに基づくコンポーネントをロード
85
+ domain_schema = self.domain_manager.get_schema(domain_id)
86
+ if not domain_schema:
87
+ return {"status": "error", "message": f"ドメイン '{domain_id}' が見つかりません。"}
88
+
89
+ spatial_encoder = SpatialEncodingEngine(domain_schema, self.ontology)
90
+
91
+ # 3. パイプラインの実行
92
+ try:
93
+ # L5: ターン開始時にリセット
94
+ reset_manager.reset_layer24_for_new_turn()
95
+
96
+ # L5: 前のターンまでのコンテキストを取得
97
+ session_context = external_state.get_context_for_next_turn()
98
+
99
+ # L1: 質問から座標を抽出
100
+ coords_info = spatial_encoder.extract_coordinates_from_question(question)
101
+ db_coordinates = [info['coordinate'] for info in coords_info]
102
+
103
+ # DBからコンテキストを取得
104
+ db_context = {}
105
+ if db_coordinates:
106
+ tile = await self.db_interface.fetch_async(db_coordinates[0])
107
+ if tile:
108
+ db_context = {db_coordinates[0]: tile}
109
+
110
+ # L4: JudgeFlowをセットア���プ
111
+ # Note: α-LobeはRunnerEngineなどを内包する概念だが、ここでは直接llm_clientを渡す
112
+ beta_lobe = BetaLobeAdvanced(self.db_interface, self.ontology)
113
+ judge_flow = JudgeCorrectionFlow(AlpheLobe(None, None), beta_lobe)
114
+ # JudgeFlowのalpha_lobeを実際のLLMクライアントに差し替え
115
+ judge_flow.alpha_lobe.generate_response = self.llm_client.generate_response
116
+
117
+ # JudgeFlowを実行して、生成・検証・修正/再生成を行う
118
+ final_result = await judge_flow.process_and_correct(
119
+ question, db_context, session_context
120
+ )
121
+
122
+ # L2 & L5: ターン結果を記録
123
+ if final_result.get("status") in ["approved", "corrected"]:
124
+ response_text = final_result["response"]
125
+ episode_palace.add_turn(question, response_text, {'referenced_coords': db_coordinates})
126
+ external_state.add_turn_summary(len(episode_palace.rooms), question, response_text, db_coordinates)
127
+
128
+ return final_result
129
+
130
+ except Exception as e:
131
+ import traceback
132
+ traceback.print_exc()
133
+ return {"status": "error", "message": str(e)}
134
+
135
+ # --- 使用例 ---
136
+ async def main():
137
+ # --- 前提条件 ---
138
+ # 1. `create_tile_from_topic.py` を実行して、`sample.iath` を作成しておく
139
+ # 2. バックグラウンドでOllama等のLLMサーバーが起動している
140
+
141
+ # 1. DeepSeekクライアントの設定
142
+ # ご自身のLLM環境に合わせてURLとモデル名を修正してください
143
+ config = DeepSeekConfig(
144
+ api_url="http://localhost:11434",
145
+ model_name="gemma:2b",
146
+ )
147
+
148
+ # 2. 推論エンジンを初期化
149
+ try:
150
+ engine = InferenceEngine(
151
+ deepseek_config=config,
152
+ db_path="心筋梗塞の急性期診断アルゴリズム.iath" # 事前に生成したファイル名
153
+ )
154
+ except Exception as e:
155
+ print(f"エンジンの初期化に失敗しました: {e}")
156
+ return
157
+
158
+ # 3. 複数ターンの会話をシミュレート
159
+ session_id = "user123_session_xyz"
160
+
161
+ # ターン1
162
+ question1 = "心筋梗塞の主な原因は何ですか?"
163
+ response1 = await engine.process_question(question1, session_id, "medical")
164
+ print("\n--- Turn 1 Final Output ---")
165
+ print(json.dumps(response1, indent=2, ensure_ascii=False))
166
+
167
+ # ターン2
168
+ question2 = "その治療法について教えてください"
169
+ response2 = await engine.process_question(question2, session_id, "medical")
170
+ print("\n--- Turn 2 Final Output ---")
171
+ print(json.dumps(response2, indent=2, ensure_ascii=False))
172
+
173
+ if __name__ == "__main__":
174
+ # このスクリプトを実行する前に、
175
+ # .venv/bin/python3 create_tile_from_topic.py を実行して、
176
+ # 「心筋梗塞の急性期診断アルゴリズム.iath」というファイルを作成しておく必要があります。
177
+ asyncio.run(main())