import spaces
import torch
import gradio as gr
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from transformers import AutoProcessor, AutoModelForVision2Seq
from pathlib import Path
import tempfile
import os
import subprocess
import sys

# Try to install flash-attn at startup if not available
try:
    import flash_attn
    print("Flash attention already installed")
except ImportError:
    print("Flash attention not found, attempting to install...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
            check=True,
            capture_output=True,
            text=True
        )
        print("Flash attention installed successfully")
    except subprocess.CalledProcessError as e:
        print(f"Could not install flash attention: {e}")
        print("Continuing without flash attention...")

# Global variables for model and processor
model = None
processor = None
model_loaded = False

def load_model():
    """Load the model and processor."""
    global model, processor, model_loaded
    if not model_loaded:
        try:
            # Load processor
            processor = AutoProcessor.from_pretrained("ibm-granite/granite-docling-258M")

            # Determine device
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Check if flash attention is available
            attn_implementation = "eager"  # default
            if device == "cuda":
                try:
                    import flash_attn
                    attn_implementation = "flash_attention_2"
                    print("Using Flash Attention 2")
                except ImportError:
                    print("Flash attention not available, using eager attention")
                    attn_implementation = "eager"

            # Load model with appropriate settings
            print(f"Loading model on {device} with {attn_implementation}...")
            if device == "cuda":
                # For GPU, use bfloat16 for better performance
                model = AutoModelForVision2Seq.from_pretrained(
                    "ibm-granite/granite-docling-258M",
                    torch_dtype=torch.bfloat16,
                    attn_implementation=attn_implementation,
                    device_map="auto",
                    trust_remote_code=True
                )
            else:
                # For CPU, use float32
                model = AutoModelForVision2Seq.from_pretrained(
                    "ibm-granite/granite-docling-258M",
                    torch_dtype=torch.float32,
                    attn_implementation="eager",
                    trust_remote_code=True
                )
                model = model.to(device)

            model_loaded = True
            print(f"Model loaded successfully on {device}")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Fallback: load on CPU without special attention settings
            try:
                processor = AutoProcessor.from_pretrained("ibm-granite/granite-docling-258M")
                model = AutoModelForVision2Seq.from_pretrained(
                    "ibm-granite/granite-docling-258M",
                    torch_dtype=torch.float32,
                    trust_remote_code=True
                )
                device = "cpu"
                model = model.to(device)
                model_loaded = True
                print("Model loaded on CPU as fallback")
            except Exception as fallback_error:
                print(f"Fallback loading also failed: {fallback_error}")
                raise

# Load model at startup
load_model()
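# NOTE: on ZeroGPU Spaces, torch.cuda.is_available() is typically False at
# import time, since the GPU is only attached inside @spaces.GPU-decorated
# calls. The startup load above therefore usually lands on CPU, and the
# handler below re-checks the device on each request.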
@spaces.GPU(duration=120)
def process_document_gpu(image, output_format="markdown"):
    """Process an uploaded image into a Docling document (GPU version)."""
    global model, processor
    try:
        # Ensure model is loaded
        if not model_loaded:
            load_model()

        # Move model to GPU if available (for ZeroGPU the model may still be on CPU)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cuda":
            # Only move if not already on CUDA
            if hasattr(model, 'device') and model.device.type != 'cuda':
                model = model.to(device)
        print(f"Processing on {device}")

        # Prepare messages
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Convert this page to docling."}
                ]
            },
        ]

        # Prepare inputs
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt")
        # Move inputs to the same device as the model
        inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}

        # Generate outputs with memory-efficient settings
        with torch.no_grad():
            if device == "cuda":
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    generated_ids = model.generate(
                        **inputs,
                        max_new_tokens=8192,
                        do_sample=False,
                        temperature=None,
                        top_p=None
                    )
            else:
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=8192,
                    do_sample=False
                )

        # Trim the prompt tokens off the output, then decode to DocTags
        prompt_length = inputs["input_ids"].shape[1]
        trimmed_generated_ids = generated_ids[:, prompt_length:]
        doctags = processor.batch_decode(
            trimmed_generated_ids,
            skip_special_tokens=False,
        )[0].lstrip()
        print(f"Generated {len(doctags)} characters of DocTags")

        # Create Docling document
        doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
        doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")

        # Generate output based on format
        if output_format == "markdown":
            content = doc.export_to_markdown()
            return content, None, None
        elif output_format == "html":
            # Create temporary file for HTML
            with tempfile.NamedTemporaryFile(mode='w', suffix='.html', delete=False) as tmp_file:
                doc.save_as_html(Path(tmp_file.name))
                html_file = tmp_file.name
            return None, html_file, None
        else:
            # Return both formats
            markdown_content = doc.export_to_markdown()
            with tempfile.NamedTemporaryFile(mode='w', suffix='.html', delete=False) as tmp_file:
                doc.save_as_html(Path(tmp_file.name))
                html_file = tmp_file.name
            return markdown_content, html_file, doctags

    except Exception as e:
        error_msg = f"Error processing document: {str(e)}"
        print(error_msg)
        import traceback
        print(traceback.format_exc())
        return error_msg, None, None

def process_document(image, output_format="markdown"):
    """Wrapper function to handle processing."""
    if image is None:
        return "Please upload an image first.", None, None
    # Call the GPU-decorated function
    return process_document_gpu(image, output_format)

def clear_results():
    """Clear all outputs."""
    return "", None, ""
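# For debugging outside the UI, the same pipeline can be driven directly.
# A minimal sketch (the file name "page.png" is a placeholder, and Pillow
# is assumed to be installed):
#
#     from PIL import Image
#     md, html_path, tags = process_document(Image.open("page.png"), "both")
#     print(md)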
--- """, elem_classes="header" ) with gr.Row(): with gr.Column(scale=1): image_input = gr.Image( label="Upload Document Image", type="pil", height=400, sources=["upload", "clipboard"], show_label=True ) format_choice = gr.Radio( choices=["markdown", "html", "both"], value="markdown", label="Output Format", info="Choose the output format for the converted document", elem_classes="format-selector" ) with gr.Row(): process_btn = gr.Button( "🚀 Convert Document", variant="primary", size="lg", scale=2 ) clear_btn = gr.Button( "đŸ—‘ī¸ Clear", variant="secondary", size="lg", scale=1 ) # Status indicator gr.Markdown( """ ### â„šī¸ Tips: - Upload clear, high-resolution images for best results - The model works best with text documents, tables, and structured content - Processing may take a few moments depending on document complexity """ ) with gr.Column(scale=2): with gr.Tab("📝 Markdown Output"): markdown_output = gr.Markdown( value="", label="Structured Markdown", show_copy_button=True, elem_classes="markdown-output" ) with gr.Tab("🌐 HTML Output"): html_output = gr.File( label="Download HTML File", file_types=[".html"], visible=True ) with gr.Tab("đŸˇī¸ Raw DocTags"): doctags_output = gr.Textbox( label="Raw DocTags Output", lines=15, max_lines=30, show_copy_button=True, placeholder="Raw DocTags will appear here after processing..." ) # Event handlers process_btn.click( fn=process_document, inputs=[image_input, format_choice], outputs=[markdown_output, html_output, doctags_output], show_progress="full" ) clear_btn.click( fn=clear_results, outputs=[markdown_output, html_output, doctags_output] ) # Examples section with gr.Accordion("📚 Example Documents", open=False): gr.Examples( examples=[ ["https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png"], ], inputs=[image_input], label="Click to load an example document", cache_examples=False ) # Footer gr.Markdown( """ ---

Powered by IBM Granite-Docling-258M

Built with ❤ī¸ using Gradio and Hugging Face Spaces

""" ) # Launch the app if __name__ == "__main__": demo.launch()