"""
FastAPI service for Czech text correction pipeline.

Combines grammar error correction (GEC) and punctuation restoration.
"""

import logging
import re
import threading
import time
from contextlib import asynccontextmanager
from typing import Callable, Dict, List, Optional, Tuple, TypeVar

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

GEC_MODEL_NAME = "ufal/byt5-large-geccc-mate"
PUNCT_MODEL_NAME = "kredor/punctuate-all"

# Maximum number of characters accepted per input text (single and batch).
MAX_INPUT_LENGTH = 5000

# Global variables for models (populated by ``lifespan`` at startup)
gec_model = None
gec_tokenizer = None
punct_pipeline = None
device = None

# Model calls run in worker threads (see ``_run_inference``). This lock keeps
# them one-at-a-time, as they effectively were when they blocked the event
# loop, so concurrent requests never run the models in parallel.
_inference_lock = threading.Lock()

_T = TypeVar("_T")

# Optimal hyperparameters for production
GEC_CONFIG = {
    "num_beams": 8,
    "do_sample": False,
    "repetition_penalty": 1.0,
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 0,
    "early_stopping": True,
    "max_new_tokens": 1500,
}

# Maximum GEC encoder input length in tokens (the ``max_length`` the tokenizer
# was always called with). ByT5 tokenizes raw UTF-8 bytes and appends one
# end-of-sequence token, so a text fits when its UTF-8 length is below this.
GEC_MAX_INPUT_TOKENS = 1024

# Token-classification label -> punctuation mark appended after the word.
# NOTE(review): depending on the checkpoint config, labels may arrive as
# generic "LABEL_<n>" names or as the marks themselves (with "0" meaning "no
# punctuation"); both spellings are accepted. Unknown labels add nothing.
PUNCT_LABELS: Dict[str, str] = {
    "LABEL_0": "",
    "LABEL_1": ".",
    "LABEL_2": ",",
    "LABEL_3": "?",
    "LABEL_4": "-",
    "LABEL_5": ":",
    "0": "",
    ".": ".",
    ",": ",",
    "?": "?",
    "-": "-",
    ":": ":",
}

# Words already ending with one of these get no additional predicted mark.
_TRAILING_PUNCT = (".", ",", "?", "!", ":", ";", "-", "…")

# Number of words sent to the punctuation model per call, so long inputs are
# not cut off at the model's maximum sequence length.
PUNCT_WINDOW_WORDS = 150

# Sentence boundary: whitespace following ., ? or !.
_SENTENCE_BREAK_RE = re.compile(r"(?<=[.?!])\s+")
# Same boundary, but capturing the whitespace so it can be restored verbatim.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.?!])(\s+)")
_WHITESPACE_SPLIT_RE = re.compile(r"(\s+)")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup, cleanup on shutdown."""
    global gec_model, gec_tokenizer, punct_pipeline, device

    logger.info("Loading models...")

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load GEC model
    logger.info("Loading Czech GEC model...")
    gec_tokenizer = AutoTokenizer.from_pretrained(GEC_MODEL_NAME)
    gec_model = AutoModelForSeq2SeqLM.from_pretrained(GEC_MODEL_NAME)
    gec_model = gec_model.to(device)
    logger.info("GEC model loaded successfully")

    # Load punctuation model
    logger.info("Loading punctuation model...")
    punct_tokenizer = AutoTokenizer.from_pretrained(PUNCT_MODEL_NAME)
    punct_model = AutoModelForTokenClassification.from_pretrained(PUNCT_MODEL_NAME)
    punct_pipeline = pipeline(
        "token-classification",
        model=punct_model,
        tokenizer=punct_tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )
    logger.info("Punctuation model loaded successfully")

    logger.info("All models loaded and ready")
    yield
    # Cleanup (if needed)
    logger.info("Shutting down...")


# Create FastAPI app with lifespan
app = FastAPI(
    title="Czech Text Correction API",
    description="API for Czech grammar error correction and punctuation restoration",
    version="1.0.0",
    lifespan=lifespan,
)

# Enable CORS
# NOTE(review): wildcard origins combined with allow_credentials=True; no
# cookie/auth handling is visible in this service - consider disabling
# credentials or restricting origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request/Response models
class CorrectionRequest(BaseModel):
    text: str = Field(..., max_length=MAX_INPUT_LENGTH, description="Czech text to correct")
    # May still arrive as an explicit null; endpoints normalize with ``or {}``.
    options: Optional[Dict] = Field(default_factory=dict, description="Optional parameters")


class CorrectionResponse(BaseModel):
    success: bool
    corrected_text: str
    processing_time_ms: Optional[float] = None
    error: Optional[str] = None


class BatchCorrectionRequest(BaseModel):
    texts: List[str] = Field(..., max_items=10, description="List of texts to correct")
    options: Optional[Dict] = Field(default_factory=dict, description="Optional parameters")


class BatchCorrectionResponse(BaseModel):
    success: bool
    corrected_texts: List[str]
    processing_time_ms: Optional[float] = None
    error: Optional[str] = None


class HealthResponse(BaseModel):
    status: str
    models_loaded: bool
    gpu_available: bool
    device: str


class InfoResponse(BaseModel):
    name: str
    version: str
    models: Dict[str, str]
    capabilities: List[str]
    max_input_length: int


def _utf8_len(text: str) -> int:
    """Return the UTF-8 byte length of *text* (the ByT5 token count minus EOS)."""
    return len(text.encode("utf-8"))


def _split_keep_separators(text: str, pattern: re.Pattern) -> List[Tuple[str, str]]:
    """Split *text* with a one-capture-group *pattern* into (segment, following separator) pairs.

    Concatenating every ``segment + separator`` reproduces *text* exactly.
    """
    parts = pattern.split(text)
    return list(zip(parts[0::2], parts[1::2] + [""]))


def _hard_split(segment: str, max_bytes: int) -> List[str]:
    """Cut one over-long word into pieces of at most *max_bytes* UTF-8 bytes, never inside a character."""
    pieces: List[str] = []
    current: List[str] = []
    size = 0
    for ch in segment:
        ch_len = _utf8_len(ch)
        if current and size + ch_len > max_bytes:
            pieces.append("".join(current))
            current, size = [], 0
        current.append(ch)
        size += ch_len
    if current:
        pieces.append("".join(current))
    return pieces


def _gec_units(text: str, max_bytes: int) -> List[Tuple[str, str]]:
    """Break *text* into (unit, separator) pairs whose units each fit in *max_bytes*.

    Units are sentences where possible, otherwise words, otherwise raw slices.
    """
    units: List[Tuple[str, str]] = []
    for sentence, sentence_sep in _split_keep_separators(text, _SENTENCE_SPLIT_RE):
        if _utf8_len(sentence) <= max_bytes:
            units.append((sentence, sentence_sep))
            continue
        words = _split_keep_separators(sentence, _WHITESPACE_SPLIT_RE)
        for idx, (word, word_sep) in enumerate(words):
            if idx == len(words) - 1:
                word_sep = sentence_sep  # the sentence's trailing whitespace
            pieces = _hard_split(word, max_bytes) if _utf8_len(word) > max_bytes else [word]
            units.extend((piece, "") for piece in pieces[:-1])
            units.append((pieces[-1], word_sep))
    return units


def _chunk_for_gec(text: str, max_bytes: int) -> List[Tuple[str, str]]:
    """Greedily pack units into (chunk, separator) pairs whose chunks fit in *max_bytes*.

    Joining every ``chunk + separator`` reproduces *text* exactly, so corrected
    chunks can be stitched back together with the original whitespace.
    """
    chunks: List[Tuple[str, str]] = []
    current, current_sep = "", ""
    for unit, sep in _gec_units(text, max_bytes):
        candidate = current + current_sep + unit
        if current and _utf8_len(candidate) > max_bytes:
            chunks.append((current, current_sep))
            current = unit
        else:
            current = candidate
        current_sep = sep
    chunks.append((current, current_sep))
    return chunks


def _generate_gec(text: str) -> str:
    """Run the GEC model once on *text*, which should fit in ``GEC_MAX_INPUT_TOKENS``."""
    if not text.strip():
        return text

    # Tokenize (truncation is only a safety net; callers chunk beforehand)
    inputs = gec_tokenizer(
        text,
        return_tensors="pt",
        max_length=GEC_MAX_INPUT_TOKENS,
        truncation=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate correction
    with torch.no_grad():
        outputs = gec_model.generate(**inputs, **GEC_CONFIG)

    # Decode
    return gec_tokenizer.decode(outputs[0], skip_special_tokens=True)


def apply_gec_correction(text: str) -> str:
    """Apply grammar error correction to text.

    Texts longer than the model's input window used to be silently truncated,
    dropping everything past roughly 1 KB. They are now split at sentence
    (or, if needed, word) boundaries. Each chunk is corrected separately and
    the results are rejoined with the original whitespace.

    Args:
        text: Input text.

    Returns:
        The corrected text, or ``text`` unchanged if it is blank.
    """
    if not text.strip():
        return text

    max_bytes = GEC_MAX_INPUT_TOKENS - 1  # leave room for the EOS token
    if _utf8_len(text) <= max_bytes:
        return _generate_gec(text)

    return "".join(_generate_gec(chunk) + sep for chunk, sep in _chunk_for_gec(text, max_bytes))


def _window_word_labels(results: List[Dict], words: List[str]) -> List[str]:
    """Collapse token predictions for ``" ".join(words)`` into one punctuation mark per word.

    Each word takes the label of its *last* sub-token, because the mark is
    placed after the word (the original code used the first sub-token).

    Args:
        results: Token-classification pipeline output for ``" ".join(words)``.
        words: The words that were joined with single spaces and classified.

    Returns:
        A list the same length as *words*; ``""`` means no mark.
    """
    labels = [""] * len(words)

    if all(r.get("end") is not None for r in results):
        # Exclusive end offset of every word inside " ".join(words).
        word_ends = []
        position = 0
        for word in words:
            position += len(word)
            word_ends.append(position)
            position += 1  # the joining space
        idx = 0
        for r in results:
            if r.get("start") == r["end"]:
                continue  # zero-width token: no characters to attribute
            while idx < len(words) and r["end"] > word_ends[idx]:
                idx += 1
            if idx == len(words):
                break
            labels[idx] = PUNCT_LABELS.get(r["entity"], "")
        return labels

    # Fallback when the tokenizer reports no offsets: a token starting with
    # "▁" opens a new word (the heuristic the original code relied on).
    idx = -1
    for i, r in enumerate(results):
        if i == 0 or r["word"].startswith("▁"):
            idx += 1
        if idx >= len(words):
            break
        labels[idx] = PUNCT_LABELS.get(r["entity"], "")
    return labels


def apply_punctuation(text: str) -> str:
    """Apply punctuation and capitalization to text.

    Whitespace is normalized to single spaces. The model sees lower-cased
    words, but every word keeps its original casing in the output, so casing
    fixed by the GEC step (e.g. proper nouns) survives. The first letter of
    each sentence is then upper-cased.

    Args:
        text: Input text.

    Returns:
        The punctuated text, or ``text`` unchanged if it is blank.
    """
    if not text.strip():
        return text

    words = text.split()

    # Predict per word, window by window. Positional alignment replaces the
    # former word->mark dict, which gave repeated words the same mark.
    # NOTE: words at a window edge only see context from one side.
    labels: List[str] = []
    for offset in range(0, len(words), PUNCT_WINDOW_WORDS):
        window = [w.lower() for w in words[offset:offset + PUNCT_WINDOW_WORDS]]
        results = punct_pipeline(" ".join(window))
        labels.extend(_window_word_labels(results, window))

    punctuated = []
    for word, mark in zip(words, labels):
        # Don't stack a predicted mark onto punctuation already present
        # (e.g. from GEC), which could previously yield "slovo,,".
        if mark and not word.endswith(_TRAILING_PUNCT):
            word += mark
        punctuated.append(word)

    # Join and capitalize the first letter of every sentence
    result = " ".join(punctuated)
    sentences = _SENTENCE_BREAK_RE.split(result)
    capitalized = " ".join(s[0].upper() + s[1:] if s else s for s in sentences)

    # Clean spacing around punctuation
    for p in [",", ".", "?", ":", "!", ";"]:
        capitalized = capitalized.replace(f" {p}", p)

    return capitalized


def process_text(text: str) -> str:
    """Full pipeline: GEC + punctuation."""
    # Step 1: Grammar correction
    gec_corrected = apply_gec_correction(text)
    # Step 2: Punctuation and capitalization
    return apply_punctuation(gec_corrected)


def _call_with_inference_lock(func: Callable[..., _T], *args) -> _T:
    """Invoke ``func(*args)`` while holding the global inference lock."""
    with _inference_lock:
        return func(*args)


async def _run_inference(func: Callable[..., _T], *args) -> _T:
    """Run a blocking model call in FastAPI's threadpool.

    Running blocking model calls directly inside ``async def`` endpoints froze
    the event loop, so even ``/api/health`` could not answer during inference.
    """
    return await run_in_threadpool(_call_with_inference_lock, func, *args)


def _elapsed_ms(start: float) -> float:
    """Milliseconds elapsed since a ``time.perf_counter()`` reading."""
    return (time.perf_counter() - start) * 1000


@app.post("/api/correct", response_model=CorrectionResponse)
async def correct_text(request: CorrectionRequest):
    """
    Correct Czech text (grammar + punctuation).

    Raises:
        HTTPException(400): blank text or text over ``MAX_INPUT_LENGTH`` chars.
            Validation happens outside the ``try`` so these are no longer
            swallowed and turned into HTTP 200 ``success=False`` responses.
    """
    start_time = time.perf_counter()

    # Validate input
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(request.text) > MAX_INPUT_LENGTH:
        raise HTTPException(
            status_code=400, detail=f"Text too long (max {MAX_INPUT_LENGTH} characters)"
        )
    options = request.options or {}

    try:
        corrected = await _run_inference(process_text, request.text)
    except Exception as e:
        logger.exception("Error processing text: %s", e)
        return CorrectionResponse(success=False, corrected_text="", error=str(e))

    # Include timing if requested
    return CorrectionResponse(
        success=True,
        corrected_text=corrected,
        processing_time_ms=_elapsed_ms(start_time) if options.get("include_timing", False) else None,
    )


@app.post("/api/correct/batch", response_model=BatchCorrectionResponse)
async def correct_batch(request: BatchCorrectionRequest):
    """
    Correct multiple Czech texts.

    Over-long entries are replaced by an error marker instead of failing the
    whole batch. An empty ``texts`` list yields HTTP 400.
    """
    start_time = time.perf_counter()

    # Validate
    if not request.texts:
        raise HTTPException(status_code=400, detail="No texts provided")
    options = request.options or {}

    try:
        # Process each text (each model call takes the inference lock separately)
        corrected_texts = []
        for text in request.texts:
            if len(text) > MAX_INPUT_LENGTH:
                corrected_texts.append("[Error: Text too long]")
            else:
                corrected_texts.append(await _run_inference(process_text, text))
    except Exception as e:
        logger.exception("Error processing batch: %s", e)
        return BatchCorrectionResponse(success=False, corrected_texts=[], error=str(e))

    return BatchCorrectionResponse(
        success=True,
        corrected_texts=corrected_texts,
        processing_time_ms=_elapsed_ms(start_time) if options.get("include_timing", False) else None,
    )


@app.post("/api/correct/gec-only", response_model=CorrectionResponse)
async def correct_gec_only(request: CorrectionRequest):
    """
    Apply only grammar error correction (no punctuation).
    """
    try:
        corrected = await _run_inference(apply_gec_correction, request.text)
    except Exception as e:
        logger.exception("Error in GEC-only correction: %s", e)
        return CorrectionResponse(success=False, corrected_text="", error=str(e))
    return CorrectionResponse(success=True, corrected_text=corrected)


@app.post("/api/correct/punct-only", response_model=CorrectionResponse)
async def correct_punct_only(request: CorrectionRequest):
    """
    Apply only punctuation restoration (no grammar correction).
    """
    try:
        corrected = await _run_inference(apply_punctuation, request.text)
    except Exception as e:
        logger.exception("Error in punctuation-only correction: %s", e)
        return CorrectionResponse(success=False, corrected_text="", error=str(e))
    return CorrectionResponse(success=True, corrected_text=corrected)


@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """
    Check API health and model status.
    """
    models_loaded = gec_model is not None and punct_pipeline is not None

    return HealthResponse(
        status="healthy" if models_loaded else "loading",
        models_loaded=models_loaded,
        gpu_available=torch.cuda.is_available(),
        device=str(device) if device else "not initialized",
    )


@app.get("/api/info", response_model=InfoResponse)
async def get_info():
    """
    Get API information and capabilities.
    """
    return InfoResponse(
        name="Czech Text Correction API",
        version="1.0.0",
        models={
            "gec": GEC_MODEL_NAME,
            "punctuation": PUNCT_MODEL_NAME,
        },
        capabilities=[
            "Grammar error correction",
            "Punctuation restoration",
            "Capitalization",
            "Batch processing",
            "Czech language focus",
        ],
        max_input_length=MAX_INPUT_LENGTH,
    )


@app.get("/")
async def root():
    """Root endpoint with API documentation link."""
    return {
        "message": "Czech Text Correction API",
        "docs": "/docs",
        "health": "/api/health",
        "info": "/api/info",
    }


if __name__ == "__main__":
    import os

    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)