|
|
import gradio as gr |
|
|
import json |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
# Lazily-populated cache for the fine-tuned model and its tokenizer.
# Both stay None until load_model() assigns them on first use.
model = tokenizer = None
|
|
|
|
|
def load_model():
    """Load the fine-tuned model and tokenizer, caching them in module globals.

    On the first call this loads the ``HuggingFaceTB/SmolLM-360M`` base
    checkpoint and wraps it with the LoRA adapter published at
    ``waliaMuskaan011/calendar-event-extractor-smollm``. Later calls return
    the cached objects without reloading.

    Returns:
        tuple: ``(model, tokenizer)`` -- the PEFT-wrapped causal LM and its
        tokenizer.
    """
    global model, tokenizer

    # Fast path: an earlier call already populated both globals.
    if model is not None and tokenizer is not None:
        return model, tokenizer

    print("🔄 Loading fine-tuned model...")

    base_model_id = "HuggingFaceTB/SmolLM-360M"
    adapter_id = "waliaMuskaan011/calendar-event-extractor-smollm"

    # Build everything in locals first so that a failure part-way through
    # (e.g. while fetching the adapter) does not leave the globals
    # half-initialised.
    loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    if loaded_tokenizer.pad_token is None:
        # The tokenizer may come without a pad token; reuse EOS so that
        # tokenizing with padding=True works.
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float32,
    )
    loaded_model = PeftModel.from_pretrained(base_model, adapter_id)
    # Inference only: make sure training-time layers such as dropout are off.
    loaded_model.eval()

    # Publish both objects together, only after everything succeeded.
    model, tokenizer = loaded_model, loaded_tokenizer
    print("✅ Model loaded successfully!")

    return model, tokenizer
|
|
|
|
|
def extract_json_from_text(text):
    """Extract the first decodable JSON object from ``text``.

    Each ``{`` is tried in order, and a JSON value is decoded starting there
    with ``json.JSONDecoder.raw_decode``. That decoder understands string
    literals, so braces inside strings (e.g. ``"a } b"``) do not end the
    object early. Any text after the object is ignored. If a candidate fails
    to decode, scanning continues from the next ``{`` instead of giving up.

    Args:
        text: Free-form text, typically raw model output.

    Returns:
        dict | None: The first JSON object found, or ``None`` when ``text``
        is not a string or contains no decodable object.
    """
    if not isinstance(text, str):
        return None

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            start = text.find("{", start + 1)
            continue
        return obj
    return None
|
|
|
|
|
def predict_calendar_event(event_text):
    """Extract calendar information from free-form event text.

    Args:
        event_text: Natural-language description of an event, as entered in
            the UI textbox. ``None`` is treated like an empty string.

    Returns:
        str: Pretty-printed JSON when the model output contains a JSON object;
        otherwise a human-readable message (empty input, unparseable model
        output, or an error raised while loading or running the model).
    """
    if not event_text or not event_text.strip():
        return "Please enter some text describing a calendar event."

    try:
        model, tokenizer = load_model()

        # Prompt template sent to the adapter; presumably mirrors the
        # fine-tuning format -- confirm before changing it.
        prompt = f"Extract calendar information from: {event_text}\nCalendar JSON:"

        # Put the inputs on the model's device so this keeps working if the
        # model is ever placed on an accelerator.
        inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
        prompt_length = inputs.input_ids.shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=150,
                do_sample=False,  # greedy decoding -> deterministic output
                pad_token_id=tokenizer.eos_token_id,
            )

        # generate() returns prompt + continuation. Decode only the new
        # tokens: slicing the decoded string by len(prompt) is fragile because
        # decode() need not reproduce the prompt text character-for-character.
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

        extracted_json = extract_json_from_text(generated_text)
        if extracted_json:
            return json.dumps(extracted_json, indent=2, ensure_ascii=False)

        # Only signal truncation when something was actually cut off.
        preview = generated_text[:200]
        if len(generated_text) > 200:
            preview += "..."
        return f"Could not extract valid JSON. Raw output: {preview}"

    except Exception as e:  # UI boundary: report the error instead of crashing
        return f"Error processing request: {e}"
|
|
|
|
|
|
|
|
# Gradio UI: a text box for the event description, a button that runs the
# extractor, a JSON output box, clickable examples, and a footer with links.
# Markdown bodies are kept flush-left so they render as headings and
# paragraphs rather than indented code blocks.
with gr.Blocks(title="Calendar Event Extractor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 📅 Calendar Event Extractor

This AI model extracts structured calendar information from natural language text.
Powered by fine-tuned SmolLM-360M with LoRA adapters.

**Try it out**: Enter any calendar-related text and get structured JSON output!
""")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="📝 Event Description",
                placeholder="e.g., 'Meeting with John tomorrow at 2pm for 1 hour'",
                lines=3,
            )
            extract_btn = gr.Button("🔍 Extract Event Info", variant="primary")

        with gr.Column():
            output_json = gr.Textbox(
                label="📋 Extracted Information (JSON)",
                lines=10,
                max_lines=15,
            )

    gr.Markdown("### 💡 Try these examples:")
    # Examples are not pre-computed (cache_examples=False), so the model is
    # not run at startup.
    examples = gr.Examples(
        examples=[
            ["Quick meeting at the coworking space on 10th May 2025 starting at 11:00 am for 45 minutes"],
            ["Coffee chat with Sarah tomorrow at 3pm"],
            ["Weekly standup every Monday at 9am on Zoom"],
            ["Doctor appointment next Friday at 2:30 PM for 30 minutes"],
            ["Team lunch at the new restaurant on 15th December"],
            ["Call with client on 25/12/2024 at 10:00 AM, needs to discuss project timeline"],
        ],
        inputs=[input_text],
        outputs=[output_json],
        fn=predict_calendar_event,
        cache_examples=False,
    )

    extract_btn.click(
        fn=predict_calendar_event,
        inputs=[input_text],
        outputs=[output_json],
    )

    gr.Markdown("""
---
**Model Details**: Fine-tuned SmolLM-360M using LoRA • **Dataset**: ~2500 calendar events • **Training**: Custom augmentation pipeline

[🤗 Model Card](https://huggingface.co/waliaMuskaan011/calendar-event-extractor-smollm) • [💻 Training Code](https://github.com/muskaanwalia098/Calendar-Event-Entity-Extraction)
""")
|
|
|
|
|
def main():
    """Start the Gradio web server for the ``demo`` app."""
    demo.launch()


# Only launch when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|