"""
FastAPI service for the Czech text correction pipeline.

Combines grammar error correction (GEC) and punctuation restoration.
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, List, Dict
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification, pipeline
import time
import re
import logging
from contextlib import asynccontextmanager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Globals populated by the lifespan handler at startup
gec_model = None
gec_tokenizer = None
punct_pipeline = None
device = None

GEC_CONFIG = {
    "num_beams": 8,
    "do_sample": False,
    "repetition_penalty": 1.0,
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 0,
    "early_stopping": True,
    "max_new_tokens": 1500
}
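
# Note on the settings above: decoding is deterministic beam search (8 beams,
# no sampling) with neutral repetition/length penalties. The generous
# max_new_tokens budget reflects the fact that ufal/byt5-large-geccc-mate is a
# byte-level (ByT5) model, so one output token corresponds to roughly one byte
# of text.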


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup, cleanup on shutdown"""
    global gec_model, gec_tokenizer, punct_pipeline, device

    logger.info("Loading models...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    logger.info("Loading Czech GEC model...")
    gec_tokenizer = AutoTokenizer.from_pretrained("ufal/byt5-large-geccc-mate")
    gec_model = AutoModelForSeq2SeqLM.from_pretrained("ufal/byt5-large-geccc-mate")
    gec_model = gec_model.to(device)
    logger.info("GEC model loaded successfully")

    logger.info("Loading punctuation model...")
    punct_tokenizer = AutoTokenizer.from_pretrained("kredor/punctuate-all")
    punct_model = AutoModelForTokenClassification.from_pretrained("kredor/punctuate-all")
    punct_pipeline = pipeline(
        "token-classification",
        model=punct_model,
        tokenizer=punct_tokenizer,
        device=0 if torch.cuda.is_available() else -1
    )
    logger.info("Punctuation model loaded successfully")

    logger.info("All models loaded and ready")

    yield

    logger.info("Shutting down...")


app = FastAPI(
    title="Czech Text Correction API",
    description="API for Czech grammar error correction and punctuation restoration",
    version="1.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
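
# Note: the wildcard CORS policy above is convenient for demos; a production
# deployment would normally restrict allow_origins to known front-end hosts.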


class CorrectionRequest(BaseModel):
    text: str = Field(..., max_length=5000, description="Czech text to correct")
    options: Optional[Dict] = Field(default={}, description="Optional parameters")


class CorrectionResponse(BaseModel):
    success: bool
    corrected_text: str
    processing_time_ms: Optional[float] = None
    error: Optional[str] = None


class BatchCorrectionRequest(BaseModel):
    texts: List[str] = Field(..., max_items=10, description="List of texts to correct")
    options: Optional[Dict] = Field(default={}, description="Optional parameters")


class BatchCorrectionResponse(BaseModel):
    success: bool
    corrected_texts: List[str]
    processing_time_ms: Optional[float] = None
    error: Optional[str] = None


class HealthResponse(BaseModel):
    status: str
    models_loaded: bool
    gpu_available: bool
    device: str


class InfoResponse(BaseModel):
    name: str
    version: str
    models: Dict[str, str]
    capabilities: List[str]
    max_input_length: int
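
# Example payloads for the models above (illustrative values only):
#
#   CorrectionRequest:  {"text": "ahoj jak se mas", "options": {"include_timing": true}}
#   CorrectionResponse: {"success": true, "corrected_text": "...", "processing_time_ms": 123.4, "error": null}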


def apply_gec_correction(text: str) -> str:
    """Apply grammar error correction to text"""
    if not text.strip():
        return text

    inputs = gec_tokenizer(
        text,
        return_tensors="pt",
        max_length=1024,
        truncation=True
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = gec_model.generate(
            **inputs,
            **GEC_CONFIG
        )

    corrected = gec_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return corrected
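
# Note: because the GEC tokenizer is byte-level, max_length=1024 above caps the
# input at roughly 1024 bytes; requests close to the 5000-character API limit
# are therefore truncated before generation.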


def apply_punctuation(text: str) -> str:
    """Apply punctuation and capitalization to text"""
    if not text.strip():
        return text

    clean_text = text.lower()
    results = punct_pipeline(clean_text)

    # Mapping from the punctuation model's labels to punctuation marks
    punct_marks = {
        'LABEL_0': '',
        'LABEL_1': '.',
        'LABEL_2': ',',
        'LABEL_3': '?',
        'LABEL_4': '-',
        'LABEL_5': ':'
    }

    # Reassemble subword tokens ('▁' marks a word start) and record the
    # punctuation predicted for each word
    punct_map = {}
    current_word = ""
    current_punct = ""

    for i, result in enumerate(results):
        word = result['word'].replace('▁', '').strip()
        punct = punct_marks.get(result['entity'], '')

        if not result['word'].startswith('▁') and i > 0:
            # Continuation subword: extend the current word
            current_word += word
        else:
            # New word: store the previous one, then start accumulating
            if current_word:
                punct_map[current_word] = current_punct
            current_word = word
            current_punct = punct

    if current_word:
        punct_map[current_word] = current_punct

    # Re-insert punctuation after the words it was predicted for
    words = clean_text.split()
    punctuated = []
    for word in words:
        if word in punct_map and punct_map[word]:
            punctuated.append(word + punct_map[word])
        else:
            punctuated.append(word)

    result = ' '.join(punctuated)

    # Capitalize the first letter of each sentence
    sentences = re.split(r'(?<=[.?!])\s+', result)
    capitalized = ' '.join(s[0].upper() + s[1:] if s else s for s in sentences)

    # Remove any stray spaces before punctuation marks
    for p in [',', '.', '?', ':', '!', ';']:
        capitalized = capitalized.replace(f' {p}', p)

    return capitalized
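
# Note: the word-to-punctuation map above is keyed by surface form, so repeated
# words receive the same punctuation at every occurrence, and capitalization is
# applied only at sentence starts (proper nouns are left lowercase).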


def process_text(text: str) -> str:
    """Full pipeline: GEC + punctuation"""
    # Grammar correction first, then punctuation restoration on its output
    gec_corrected = apply_gec_correction(text)
    final_text = apply_punctuation(gec_corrected)
    return final_text


@app.post("/api/correct", response_model=CorrectionResponse)
async def correct_text(request: CorrectionRequest):
    """
    Correct Czech text (grammar + punctuation)
    """
    try:
        start_time = time.time()

        if not request.text.strip():
            raise HTTPException(status_code=400, detail="Text cannot be empty")

        if len(request.text) > 5000:
            raise HTTPException(status_code=400, detail="Text too long (max 5000 characters)")

        corrected = process_text(request.text)

        processing_time = (time.time() - start_time) * 1000

        response = CorrectionResponse(
            success=True,
            corrected_text=corrected
        )

        if (request.options or {}).get("include_timing", False):
            response.processing_time_ms = processing_time

        return response

    except HTTPException:
        # Let validation errors propagate as proper 4xx responses
        raise
    except Exception as e:
        logger.error(f"Error processing text: {str(e)}")
        return CorrectionResponse(
            success=False,
            corrected_text="",
            error=str(e)
        )
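
# Example client call (an illustrative sketch; assumes the service is running
# locally on the default port 7860 used in the __main__ block below):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/api/correct",
#       json={"text": "ahoj jak se mas", "options": {"include_timing": True}},
#   )
#   print(resp.json()["corrected_text"])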


@app.post("/api/correct/batch", response_model=BatchCorrectionResponse)
async def correct_batch(request: BatchCorrectionRequest):
    """
    Correct multiple Czech texts
    """
    try:
        start_time = time.time()

        if not request.texts:
            raise HTTPException(status_code=400, detail="No texts provided")

        corrected_texts = []
        for text in request.texts:
            if len(text) > 5000:
                corrected_texts.append("[Error: Text too long]")
            else:
                corrected = process_text(text)
                corrected_texts.append(corrected)

        processing_time = (time.time() - start_time) * 1000

        response = BatchCorrectionResponse(
            success=True,
            corrected_texts=corrected_texts
        )

        if (request.options or {}).get("include_timing", False):
            response.processing_time_ms = processing_time

        return response

    except HTTPException:
        # Let validation errors propagate as proper 4xx responses
        raise
    except Exception as e:
        logger.error(f"Error processing batch: {str(e)}")
        return BatchCorrectionResponse(
            success=False,
            corrected_texts=[],
            error=str(e)
        )
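
# Note: the batch endpoint above processes texts sequentially, so latency grows
# roughly linearly with batch size (bounded by the 10-item request limit).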


# Unlike /api/correct, the single-stage endpoints below skip the explicit
# empty-text check; overlong input is still rejected by CorrectionRequest's
# max_length constraint.
@app.post("/api/correct/gec-only", response_model=CorrectionResponse)
async def correct_gec_only(request: CorrectionRequest):
    """
    Apply only grammar error correction (no punctuation)
    """
    try:
        corrected = apply_gec_correction(request.text)
        return CorrectionResponse(
            success=True,
            corrected_text=corrected
        )
    except Exception as e:
        logger.error(f"Error in GEC-only correction: {str(e)}")
        return CorrectionResponse(
            success=False,
            corrected_text="",
            error=str(e)
        )


@app.post("/api/correct/punct-only", response_model=CorrectionResponse)
async def correct_punct_only(request: CorrectionRequest):
    """
    Apply only punctuation restoration (no grammar correction)
    """
    try:
        corrected = apply_punctuation(request.text)
        return CorrectionResponse(
            success=True,
            corrected_text=corrected
        )
    except Exception as e:
        logger.error(f"Error in punctuation-only correction: {str(e)}")
        return CorrectionResponse(
            success=False,
            corrected_text="",
            error=str(e)
        )


@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """
    Check API health and model status
    """
    models_loaded = (gec_model is not None and punct_pipeline is not None)

    return HealthResponse(
        status="healthy" if models_loaded else "loading",
        models_loaded=models_loaded,
        gpu_available=torch.cuda.is_available(),
        device=str(device) if device else "not initialized"
    )


@app.get("/api/info", response_model=InfoResponse)
async def get_info():
    """
    Get API information and capabilities
    """
    return InfoResponse(
        name="Czech Text Correction API",
        version="1.0.0",
        models={
            "gec": "ufal/byt5-large-geccc-mate",
            "punctuation": "kredor/punctuate-all"
        },
        capabilities=[
            "Grammar error correction",
            "Punctuation restoration",
            "Capitalization",
            "Batch processing",
            "Czech language focus"
        ],
        max_input_length=5000
    )


@app.get("/")
async def root():
    """Root endpoint with API documentation link"""
    return {
        "message": "Czech Text Correction API",
        "docs": "/docs",
        "health": "/api/health",
        "info": "/api/info"
    }


if __name__ == "__main__":
    import uvicorn
    import os

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
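
# Alternatively, the app can be served via the uvicorn CLI (assuming this file
# is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860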