Thadillo's picture
First commit.
1c4a712 verified
raw
history blame
3.53 kB
"""
AI-powered submission analyzer using Hugging Face zero-shot classification.
This module provides free, offline classification without requiring API keys.
"""
from transformers import pipeline
import logging
logger = logging.getLogger(__name__)


class SubmissionAnalyzer:
    """Classify free-text submissions into a fixed set of categories.

    Uses a Hugging Face zero-shot classification pipeline. The model is not
    loaded at construction time; it is loaded on the first ``analyze`` call
    that actually needs it and then cached on ``self.classifier``.
    """

    def __init__(self, model_name="facebook/bart-large-mnli", device=-1,
                 fallback_category='Problem'):
        """Set up categories and model settings; the model itself loads lazily.

        Args:
            model_name (str): Model id passed to ``transformers.pipeline``.
                The default, facebook/bart-large-mnli, is a good balance of
                speed and accuracy.
            device (int): Device passed to ``transformers.pipeline``; ``-1``
                runs on CPU, ``0`` selects the first GPU.
            fallback_category (str): Category returned when a message has no
                text to classify or when classification fails.
        """
        self.model_name = model_name
        self.device = device
        self.fallback_category = fallback_category
        # Filled in by _load_model() on first use.
        self.classifier = None
        self.categories = [
            'Vision',
            'Problem',
            'Objectives',
            'Directives',
            'Values',
            'Actions',
        ]
        # Category descriptions for better classification: each is appended
        # to the category name to form the label the model scores.
        self.category_descriptions = {
            'Vision': 'future aspirations, desired outcomes, what success looks like',
            'Problem': 'current issues, frustrations, causes of problems',
            'Objectives': 'specific goals to achieve',
            'Directives': 'restrictions or requirements for solution design',
            'Values': 'principles or restrictions for setting objectives',
            'Actions': 'concrete steps, interventions, or activities to implement',
        }

    def _load_model(self):
        """Create the zero-shot pipeline once and cache it on ``self.classifier``.

        Raises:
            Exception: Any exception raised by ``transformers.pipeline`` is
                logged (with traceback) and re-raised unchanged.
        """
        if self.classifier is not None:
            return
        logger.info("Loading zero-shot classification model...")
        try:
            self.classifier = pipeline(
                "zero-shot-classification",
                model=self.model_name,
                device=self.device,
            )
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            raise
        logger.info("Model loaded successfully!")

    def _candidate_labels(self):
        """Return an ordered mapping of classifier label -> category name.

        Each label is ``"<Category>: <description>"``, in the order of
        ``self.categories``. Looking the winning label up in this mapping
        (instead of splitting its text on ':') stays correct even when a
        category name itself contains ':'.

        Raises:
            KeyError: If a category has no entry in ``category_descriptions``.
        """
        return {
            f"{cat}: {self.category_descriptions[cat]}": cat
            for cat in self.categories
        }

    def analyze(self, message):
        """Classify a submission message into one of the predefined categories.

        Args:
            message (str): The submission message to classify.

        Returns:
            str: The predicted category. ``self.fallback_category`` is returned
            when ``message`` is not a string, is empty or whitespace-only, or
            when classification raises.

        Raises:
            Exception: Propagated from ``_load_model`` if the model cannot be
                loaded.
        """
        if not isinstance(message, str) or not message.strip():
            # Nothing meaningful to classify; avoid loading the model just to
            # have the pipeline reject the input.
            logger.warning(
                "Empty or non-text message; using fallback category %r",
                self.fallback_category,
            )
            return self.fallback_category

        self._load_model()
        try:
            label_to_category = self._candidate_labels()
            result = self.classifier(
                message,
                list(label_to_category),
                multi_label=False,
            )
            # Labels come back sorted by score, highest first.
            top_label = result['labels'][0]
            category = label_to_category[top_label]
            confidence = result['scores'][0]
        except Exception as e:
            # Deliberate best-effort: one bad message must not break callers.
            logger.exception("Error analyzing message: %s", e)
            return self.fallback_category

        logger.info(
            "Classified message as: %s (confidence: %.2f)", category, confidence
        )
        return category

    def analyze_batch(self, messages):
        """Classify multiple messages, one at a time.

        Each message goes through ``analyze`` independently, so a failure on
        one message only yields the fallback category for that message.

        Args:
            messages (list): List of submission messages.

        Returns:
            list: Predicted categories, in the same order as ``messages``.
        """
        return [self.analyze(msg) for msg in messages]
# Process-wide SubmissionAnalyzer, created on the first get_analyzer() call.
_analyzer = None


def get_analyzer():
    """Return the shared SubmissionAnalyzer, constructing it on first use.

    Every caller in the process receives the same instance, so its lazily
    loaded model is reused rather than loaded again.
    """
    global _analyzer
    if _analyzer is not None:
        return _analyzer
    _analyzer = SubmissionAnalyzer()
    return _analyzer