NightPrince committed on
Commit 684a742 · verified · 1 Parent(s): a9d338c

Update app.py

Files changed (1)
  1. app.py +89 -36
app.py CHANGED
@@ -1,13 +1,13 @@
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
+ import streamlit as st
+ from transformers import (
+     AutoTokenizer, AutoModelForSequenceClassification,
+     pipeline, BlipProcessor, BlipForConditionalGeneration
+ )
  from peft import PeftModel
+ from PIL import Image
+ import requests
 
- # ✅ Load model and tokenizer
- base_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=9)
- model = PeftModel.from_pretrained(base_model, "NightPrince/peft-distilbert-toxic-classifier")
- tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
-
- # ✅ Label mapping
+ # 1️⃣ Setup label mapping
  id2label = {
      0: "Child Sexual Exploitation",
      1: "Elections",
@@ -19,44 +19,97 @@ id2label = {
      7: "Violent Crimes",
      8: "unsafe"
  }
- # ✅ Pipeline for easy inference
- pipe = pipeline(
-     "text-classification",
-     model=model,
-     tokenizer=tokenizer,
-     return_all_scores=True
- )
 
- # ✅ Define prediction function
- def classify_toxicity(query, image_description):
-     combined_text = query + " [SEP] " + image_description
-     preds = pipe(combined_text)[0]  # Get scores for all classes
+ # 2️⃣ Load BLIP captioning model
+ @st.cache_resource
+ def load_caption_model():
+     processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+     model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+     return processor, model
+
+ def caption_image(img):
+     processor, model = load_caption_model()
+     inputs = processor(images=img, return_tensors="pt")
+     out = model.generate(**inputs)
+     caption = processor.decode(out[0], skip_special_tokens=True)
+     return caption
+
+ # 3️⃣ Load your DistilBERT+LoRA classifier
+ @st.cache_resource
+ def load_toxic_classifier():
+     base_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=9)
+     model = PeftModel.from_pretrained(base_model, "NightPrince/peft-distilbert-toxic-classifier")
+     tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+     pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
+     return pipe
+
+ def classify_toxicity(text_input, caption):
+     pipe = load_toxic_classifier()
+     full_input = text_input + " [SEP] " + caption
+     preds = pipe(full_input)[0]
      preds_sorted = sorted(preds, key=lambda x: x['score'], reverse=True)
 
      top_label = preds_sorted[0]['label']
      top_score = preds_sorted[0]['score']
-
-     # Map label ID back to human-readable label
      label_id = int(top_label.split("_")[-1]) if "_" in top_label else int(top_label)
      final_label = id2label.get(label_id, "Unknown")
 
-     # Display all class scores (optional)
      scores_table = "\n".join(
          [f"{id2label[int(item['label'].split('_')[-1])]}: {round(item['score']*100, 2)}%" for item in preds]
      )
 
-     return f"Top Prediction: {final_label} ({round(top_score*100, 2)}%)\n\nFull Class Scores:\n{scores_table}"
-
- # ✅ Gradio UI
- iface = gr.Interface(
-     fn=classify_toxicity,
-     inputs=[
-         gr.Textbox(label="User Query"),
-         gr.Textbox(label="Image Description"),
-     ],
-     outputs=gr.Textbox(label="Toxicity Prediction"),
-     title="Toxic Category Classifier (DistilBERT + LoRA)",
-     description="Enter a user query and image description. The model will classify into one of the 9 toxic categories."
- )
+     return final_label, top_score, scores_table
+
+ # 4️⃣ Llama Guard (real API)
+ def llama_guard_filter(text):
+     url = "https://openrouter.ai/api/v1/chat/completions"
+     headers = {
+         "Authorization": f"Bearer {st.secrets['OPENROUTER_API_KEY']}",
+         "Content-Type": "application/json"
+     }
+     data = {
+         "model": "meta-llama/llama-guard-3-8b",
+         "messages": [
+             {"role": "system", "content": "You are a safety classifier. Return only 'safe' or 'unsafe'."},
+             {"role": "user", "content": text}
+         ]
+     }
+
+     try:
+         response = requests.post(url, headers=headers, json=data)
+         result = response.json()["choices"][0]["message"]["content"].strip().lower()
+         return result
+     except Exception as e:
+         print("Llama Guard API Error:", e)
+         return "safe"  # Fallback if error
+
+ # 🌐 Streamlit UI
+ st.set_page_config(page_title="Toxic Moderation System", layout="centered")
+ st.title("🛡️ Dual-Stage Toxic Moderation")
+ st.markdown("Moderate text and images using **Llama Guard** + **DistilBERT-LoRA**.\n\n- Stage 1: Hard Safety Filter (Llama Guard)\n- Stage 2: Fine Toxic Classifier (LoRA DistilBERT)")
+
+ text_input = st.text_area("✏️ Enter a text message", height=150)
+ uploaded_image = st.file_uploader("📷 Upload an image (optional)", type=["jpg", "jpeg", "png"])
+
+ image_caption = ""
+ if uploaded_image:
+     image = Image.open(uploaded_image)
+     st.image(image, caption="Uploaded Image", use_column_width=True)
+     with st.spinner("🔍 Generating caption with BLIP..."):
+         image_caption = caption_image(image)
+     st.success(f"📝 Caption: `{image_caption}`")
+
+ if st.button("🚀 Run Moderation"):
+     full_text = text_input + " [SEP] " + image_caption
+     with st.spinner("🛡️ Stage 1: Llama Guard..."):
+         safety = llama_guard_filter(full_text)
+
+     if safety == "unsafe":
+         st.error("❌ Llama Guard flagged this content as **UNSAFE**.\nModeration stopped.")
+     else:
+         st.success("✅ Safe by Llama Guard. Proceeding to classifier...")
+         with st.spinner("🧠 Stage 2: DistilBERT Toxic Classifier..."):
+             label, score, scores = classify_toxicity(text_input, image_caption)
+         st.markdown(f"### 🔍 Prediction: `{label}` ({round(score*100, 2)}%)")
+         st.text("📊 Class Probabilities:\n" + scores)
 
- iface.launch()
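
Below is a minimal sketch of how the two-stage flow added in this commit (Llama Guard 3 via OpenRouter first, then the LoRA DistilBERT classifier) could be exercised outside Streamlit as a quick smoke test. It reuses the same model IDs and request shape as app.py, but reads OPENROUTER_API_KEY from an environment variable instead of st.secrets; the helper names llama_guard() and run_moderation() and the sample inputs are hypothetical, not part of the committed app.

# Smoke-test sketch (not part of the commit). Assumptions: same model IDs as app.py,
# OPENROUTER_API_KEY set as an environment variable, hypothetical helper names.
import os
import requests
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from peft import PeftModel

# Stage 2 model: DistilBERT base plus the LoRA adapter from the Hub, as in app.py.
base = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=9)
model = PeftModel.from_pretrained(base, "NightPrince/peft-distilbert-toxic-classifier")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)

def llama_guard(text):
    # Stage 1: ask Llama Guard 3 (via OpenRouter) whether the text is safe.
    resp = requests.post(
        "https://openrouter.ai/api/v1/chat/completions",
        headers={"Authorization": f"Bearer {os.environ['OPENROUTER_API_KEY']}"},
        json={
            "model": "meta-llama/llama-guard-3-8b",
            "messages": [
                {"role": "system", "content": "You are a safety classifier. Return only 'safe' or 'unsafe'."},
                {"role": "user", "content": text},
            ],
        },
        timeout=30,
    )
    return resp.json()["choices"][0]["message"]["content"].strip().lower()

def run_moderation(text, caption=""):
    combined = text + " [SEP] " + caption
    if llama_guard(combined) == "unsafe":
        return "unsafe", None, None
    scores = classifier(combined)[0]             # scores for all 9 classes
    top = max(scores, key=lambda s: s["score"])  # highest-scoring raw label, e.g. "LABEL_3"
    return "safe", top["label"], top["score"]

if __name__ == "__main__":
    print(run_moderation("hello there", "a photo of a cat"))  # hypothetical sample input

As a side note, recent transformers releases deprecate return_all_scores=True on text-classification pipelines in favour of top_k=None, and recent Streamlit releases deprecate use_column_width on st.image in favour of use_container_width; the committed code keeps the older arguments, which may emit deprecation warnings but should still run.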