"""Local test harness for the IME report pipeline.

Loads a gender-verifier model and a LoRA-adapted grammar model from the local
Hugging Face cache and exposes correction functions used by the document
processing / PDF stage.
"""

import os
import re
from typing import Optional

import textract
import torch
from fpdf import FPDF
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
# All paths are local.
INPUT_DOC_PATH = "Doreen.doc"
OUTPUT_PDF_PATH = "Doreen_DeFio_Report_Local_Test.pdf"

# --- Model paths (resolved from the local Hugging Face cache) ---
GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
BASE_MODEL_PATH = "unsloth/gemma-2b-it"
# Local folder containing the fine-tuned grammar LoRA adapter.
LORA_ADAPTER_PATH = "gemma-grammar-lora"

# Lower bound on the generation budget. Longer inputs get proportionally more
# room so that corrected paragraphs are not silently cut off mid-sentence.
MIN_NEW_TOKENS = 64

# --- Global model handles, populated by load_all_models() ---
grammar_model = None
grammar_tokenizer = None
gender_model = None
gender_tokenizer = None
device = "cpu"

# Leading labels the gender model sometimes prefixes to its answer.
_LEADING_LABEL = re.compile(
    r'^(Corrected sentence:|Correct:|Prompt:)\s*', re.IGNORECASE
)

# Regex safety net applied (in order) to the gender model's output.
# NOTE(review): "her wife" / "his husband" can be correct (same-sex spouses);
# confirm these rewrites are wanted for every report.
_GENDER_FIXES = (
    (re.compile(r'\bher wife\b', re.IGNORECASE), 'her husband'),
    (re.compile(r'\bhis husband\b', re.IGNORECASE), 'his wife'),
    (re.compile(r'\bhe is a girl\b', re.IGNORECASE), 'he is a boy'),
    (re.compile(r'\bshe is a boy\b', re.IGNORECASE), 'she is a girl'),
)


# --- 1. Model Loading Logic (from main.py) ---
def load_all_models():
    """Load the gender-verifier and grammar (base + LoRA) models into globals.

    Returns:
        True when every model and tokenizer loaded. False if any step raised;
        the error is printed.
    """
    global grammar_model, grammar_tokenizer, gender_model, gender_tokenizer
    print("--- Starting Model Loading ---")
    try:
        print(f"Loading gender model from cache: {GENDER_MODEL_PATH}")
        gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH)
        gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH).to(device)
        print("✅ Gender verifier model loaded successfully!")

        print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}")
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH, dtype=torch.float32
        ).to(device)
        grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        print(f"Applying LoRA adapter from local folder: {LORA_ADAPTER_PATH}")
        grammar_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH).to(device)
        print("✅ Grammar correction model loaded successfully!")

        # Ensure each tokenizer has a pad token (falling back to EOS) so the
        # pad_token_id handed to generate() below is never None.
        for tokenizer in (grammar_tokenizer, gender_tokenizer):
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        print(f"❌ Critical error during model loading: {e}")
        return False

    print("--- Model Loading Complete ---")
    return True


# --- 2. Correction Functions (adapted from main.py) ---
def _generate_response(model, tokenizer, prompt: str,
                       max_new_tokens: Optional[int] = None,
                       **generate_kwargs) -> str:
    """Greedily generate a continuation of *prompt* and return the first answer.

    Only the newly generated tokens are decoded. The echoed prompt, and any
    "Response:" that the caller's own text happens to contain, therefore
    cannot confuse the answer extraction.

    Args:
        model: A loaded causal LM (or PEFT wrapper) exposing ``generate``.
        tokenizer: The tokenizer that matches *model*.
        prompt: The full prompt text.
        max_new_tokens: Generation budget. When None, it is
            ``max(MIN_NEW_TOKENS, prompt length in tokens)``.
        **generate_kwargs: Extra keyword arguments forwarded to ``generate``.

    Returns:
        The stripped answer. It ends at the first further "Response:" or
        "\\nPrompt:" marker the model produces.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]
    if max_new_tokens is None:
        max_new_tokens = max(MIN_NEW_TOKENS, prompt_len)

    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding; a temperature would be ignored
        pad_token_id=tokenizer.pad_token_id,
        **generate_kwargs,
    )
    generated = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)

    # The model may run on into another "Response:" or start a fresh
    # "Prompt:" block; keep only the first answer.
    answer = generated.split("Response:", 1)[0].strip()
    answer = answer.split("\nPrompt:", 1)[0]
    return answer.strip()


def _match_case(matched: str, replacement: str) -> str:
    """Capitalise *replacement* when *matched* starts with an uppercase letter."""
    if matched[:1].isupper():
        return replacement[:1].upper() + replacement[1:]
    return replacement


def run_grammar_correction(text: str, *, max_new_tokens: Optional[int] = None) -> str:
    """Correct the grammar of *text* with the LoRA-adapted model.

    Args:
        text: The text to correct.
        max_new_tokens: Optional generation budget; by default it scales with
            the input so long paragraphs are not truncated.

    Returns:
        The corrected text. *text* is returned unchanged when the model is not
        loaded or *text* is blank.
    """
    if grammar_model is None or not text.strip():
        return text
    return _generate_response(
        grammar_model,
        grammar_tokenizer,
        f"Prompt: {text}\nResponse:",
        max_new_tokens,
    )


def run_gender_correction(text: str, *, max_new_tokens: Optional[int] = None) -> str:
    """Correct gendered wording in *text* with the gender model, then a regex net.

    Args:
        text: The text to correct.
        max_new_tokens: Optional generation budget; by default it scales with
            the input so long paragraphs are not truncated.

    Returns:
        The cleaned, corrected text. *text* is returned unchanged when the
        model is not loaded or *text* is blank.
    """
    if gender_model is None or not text.strip():
        return text

    prompt = (
        "Prompt: Please rewrite the sentence with correct grammar and gender. "
        f"Output ONLY the corrected sentence:\n{text}\nResponse:"
    )
    answer = _generate_response(
        gender_model,
        gender_tokenizer,
        prompt,
        max_new_tokens,
        eos_token_id=gender_tokenizer.eos_token_id,
    )
    cleaned = _LEADING_LABEL.sub('', answer).strip().strip('"')

    # Apply the safety net. Matching is case-insensitive, so keep a capital
    # first letter (e.g. at sentence start) in the replacement.
    for pattern, replacement in _GENDER_FIXES:
        cleaned = pattern.sub(
            lambda m, repl=replacement: _match_case(m.group(0), repl), cleaned
        )
    return cleaned
# --- 3. Document Processing Logic (from document_pipeline.py) ---

# Field labels recognised at the start of a line, e.g. "Client Name: Morgan".
# TODO: the original label list was abridged ("..."); restore the full set.
KEY_VALUE_FIELDS = ("Client Name", "Date of Exam")
KEY_VALUE_PATTERN = re.compile(
    r'^\s*(' + '|'.join(map(re.escape, KEY_VALUE_FIELDS)) + r'):\s*(.*)',
    re.IGNORECASE,
)


def extract_text_from_doc(filepath):
    """Extract all text from *filepath* using textract.

    Returns:
        The UTF-8-decoded text, or None if textract or decoding raised (the
        error is printed).
    """
    try:
        text_bytes = textract.process(filepath)
        return text_bytes.decode('utf-8')
    except Exception as e:
        print(f"Error reading document with textract: {e}")
        return None


def parse_and_correct_text(raw_text):
    """Build the report's field -> corrected-text mapping.

    NOTE: parsing is still a stub. *raw_text* is not read yet, and the
    returned dict holds fixed demo values so the PDF stage can be exercised.
    TODO: for each line of *raw_text* matching KEY_VALUE_PATTERN, store
    run_gender_correction(run_grammar_correction(value)) under its label.

    Returns:
        A dict mapping field names to corrected strings.
    """
    structured_data = {}
    structured_data['Client Name'] = run_grammar_correction("Morgan & Morgan")
    structured_data['Intake'] = run_gender_correction(run_grammar_correction(
        "The IME physician asked the examinee if he has any issues sleeping. "
        "The examinee replied yes."
    ))
    return structured_data


class PDF(FPDF):
    """FPDF subclass with a report title header, page-number footer and DejaVu fonts."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register the Unicode TTF fonts once here rather than in header(),
        # which runs for every page.
        self.add_font('DejaVu', '', 'DejaVuSans.ttf', uni=True)
        self.add_font('DejaVu', 'B', 'DejaVuSans-Bold.ttf', uni=True)

    def header(self):
        """Draw the centred report title at the top of each page."""
        self.set_font('DejaVu', 'B', 15)
        self.cell(0, 10, 'IME WatchDog Report', 0, 1, 'C')
        self.ln(10)

    def footer(self):
        """Draw the centred page number 15 units above the page bottom."""
        self.set_y(-15)
        self.set_font('Helvetica', 'I', 8)
        self.cell(0, 10, f'Page {self.page_no()}', 0, 0, 'C')


def generate_pdf(data, output_path):
    """Write *data* to *output_path* as a PDF: each key in bold, its value below."""
    pdf = PDF()
    pdf.add_page()
    for key, value in data.items():
        pdf.set_font('DejaVu', 'B', 12)
        pdf.multi_cell(0, 8, f"{key}:")
        pdf.set_font('DejaVu', '', 12)
        pdf.multi_cell(0, 8, str(value))
        pdf.ln(4)
    pdf.output(output_path)
    print(f"✅ Successfully generated PDF report at: {output_path}")


# --- Main Execution ---
def main():
    """Run the local pipeline end to end.

    Returns:
        0 on success, or 1 when model loading or text extraction fails.
    """
    print("--- Starting Local Test Pipeline ---")
    # Models are expected to be in the local cache already
    # (run download_models.py first).
    if not load_all_models():
        print("❌ Aborting: models could not be loaded.")
        return 1

    raw_text = extract_text_from_doc(INPUT_DOC_PATH)
    if not raw_text:
        print(f"❌ Aborting: no text extracted from {INPUT_DOC_PATH}.")
        return 1

    corrected_data = parse_and_correct_text(raw_text)
    generate_pdf(corrected_data, OUTPUT_PDF_PATH)
    print("--- Pipeline Finished ---")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())