| """ | |
| AI-powered submission analyzer using Hugging Face zero-shot classification. | |
| This module provides free, offline classification without requiring API keys. | |
| """ | |
| from transformers import pipeline | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class SubmissionAnalyzer:
    def __init__(self):
        """Initialize the analyzer; the classification model is lazy-loaded on first use."""
        self.classifier = None
        self.categories = [
            'Vision',
            'Problem',
            'Objectives',
            'Directives',
            'Values',
            'Actions',
        ]
        # Category descriptions for better classification
        self.category_descriptions = {
            'Vision': 'future aspirations, desired outcomes, what success looks like',
            'Problem': 'current issues, frustrations, causes of problems',
            'Objectives': 'specific goals to achieve',
            'Directives': 'restrictions or requirements for solution design',
            'Values': 'principles or restrictions for setting objectives',
            'Actions': 'concrete steps, interventions, or activities to implement',
        }
    def _load_model(self):
        """Lazily load the model the first time it is needed."""
        if self.classifier is None:
            try:
                logger.info("Loading zero-shot classification model...")
                # facebook/bart-large-mnli offers a good balance of speed and accuracy
                self.classifier = pipeline(
                    "zero-shot-classification",
                    model="facebook/bart-large-mnli",
                    device=-1,  # -1 = CPU; set to 0 to use the first GPU
                )
                logger.info("Model loaded successfully!")
            except Exception as e:
                logger.error(f"Error loading model: {e}")
                raise
    def analyze(self, message):
        """
        Classify a submission message into one of the predefined categories.

        Args:
            message (str): The submission message to classify

        Returns:
            str: The predicted category
        """
        self._load_model()
        try:
            # Use category descriptions as labels for better accuracy
            candidate_labels = [
                f"{cat}: {self.category_descriptions[cat]}"
                for cat in self.categories
            ]
            # Run classification
            result = self.classifier(
                message,
                candidate_labels,
                multi_label=False,
            )
            # Extract the category name from the winning label
            top_label = result['labels'][0]
            category = top_label.split(':')[0]
            logger.info(
                f"Classified message as: {category} "
                f"(confidence: {result['scores'][0]:.2f})"
            )
            return category
        except Exception as e:
            logger.error(f"Error analyzing message: {e}")
            # Fall back to the Problem category if analysis fails
            return 'Problem'
    def analyze_batch(self, messages):
        """
        Classify multiple messages at once.

        Args:
            messages (list): List of submission messages

        Returns:
            list: List of predicted categories
        """
        return [self.analyze(msg) for msg in messages]
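
    # Hedged sketch, not part of the original module: the Hugging Face zero-shot
    # pipeline also accepts a list of texts in a single call, which avoids the
    # per-message Python loop in analyze_batch above. The method name
    # analyze_batch_fast is an illustrative assumption.
    def analyze_batch_fast(self, messages):
        """Classify multiple messages in one pipeline call (illustrative sketch)."""
        self._load_model()
        candidate_labels = [
            f"{cat}: {self.category_descriptions[cat]}"
            for cat in self.categories
        ]
        results = self.classifier(messages, candidate_labels, multi_label=False)
        # Normalize in case the pipeline returns a single dict rather than a list
        if isinstance(results, dict):
            results = [results]
        return [r['labels'][0].split(':')[0] for r in results]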
# Global analyzer instance, created lazily on first use
_analyzer = None


def get_analyzer():
    """Get or create the global analyzer instance."""
    global _analyzer
    if _analyzer is None:
        _analyzer = SubmissionAnalyzer()
    return _analyzer
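

# Hedged usage sketch (not part of the original module): how the analyzer might
# be called from application code. The example messages are illustrative only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    analyzer = get_analyzer()
    # Single message: returns one category name, e.g. 'Problem'
    print(analyzer.analyze("Our team wastes hours every week on manual data entry."))
    # Multiple messages: returns a list of category names
    print(analyzer.analyze_batch([
        "We want every resident to have access to affordable housing by 2030.",
        "Launch a pilot program in two districts next quarter.",
    ]))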