# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Sequence

import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import print_log
from rich.console import Console
from rich.table import Table

from xtuner.registry import BUILDER


class MMLUMetric(BaseMetric):
    METAINFO = {
        'subcategories': {
            'abstract_algebra': ['math'],
            'anatomy': ['health'],
            'astronomy': ['physics'],
            'business_ethics': ['business'],
            'clinical_knowledge': ['health'],
            'college_biology': ['biology'],
            'college_chemistry': ['chemistry'],
            'college_computer_science': ['computer science'],
            'college_mathematics': ['math'],
            'college_medicine': ['health'],
            'college_physics': ['physics'],
            'computer_security': ['computer science'],
            'conceptual_physics': ['physics'],
            'econometrics': ['economics'],
            'electrical_engineering': ['engineering'],
            'elementary_mathematics': ['math'],
            'formal_logic': ['philosophy'],
            'global_facts': ['other'],
            'high_school_biology': ['biology'],
            'high_school_chemistry': ['chemistry'],
            'high_school_computer_science': ['computer science'],
            'high_school_european_history': ['history'],
            'high_school_geography': ['geography'],
            'high_school_government_and_politics': ['politics'],
            'high_school_macroeconomics': ['economics'],
            'high_school_mathematics': ['math'],
            'high_school_microeconomics': ['economics'],
            'high_school_physics': ['physics'],
            'high_school_psychology': ['psychology'],
            'high_school_statistics': ['math'],
            'high_school_us_history': ['history'],
            'high_school_world_history': ['history'],
            'human_aging': ['health'],
            'human_sexuality': ['culture'],
            'international_law': ['law'],
            'jurisprudence': ['law'],
            'logical_fallacies': ['philosophy'],
            'machine_learning': ['computer science'],
            'management': ['business'],
            'marketing': ['business'],
            'medical_genetics': ['health'],
            'miscellaneous': ['other'],
            'moral_disputes': ['philosophy'],
            'moral_scenarios': ['philosophy'],
            'nutrition': ['health'],
            'philosophy': ['philosophy'],
            'prehistory': ['history'],
            'professional_accounting': ['other'],
            'professional_law': ['law'],
            'professional_medicine': ['health'],
            'professional_psychology': ['psychology'],
            'public_relations': ['politics'],
            'security_studies': ['politics'],
            'sociology': ['culture'],
            'us_foreign_policy': ['politics'],
            'virology': ['health'],
            'world_religions': ['philosophy'],
        },
        'categories': {
            'STEM': [
                'physics', 'chemistry', 'biology', 'computer science',
                'math', 'engineering'
            ],
            'humanities': ['history', 'philosophy', 'law'],
            'social sciences': [
                'politics', 'culture', 'economics', 'geography', 'psychology'
            ],
            'other (business, health, misc.)': ['other', 'business',
                                                'health'],
        },
    }
    METAINFO['subcategories_list'] = list({
        subcat
        for subcats in METAINFO['subcategories'].values()
        for subcat in subcats
    })

    def __init__(self, tokenizer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        tokenizer = BUILDER.build(tokenizer)
        self.abcd_idx = [
            tokenizer.encode('A', add_special_tokens=False)[0],
            tokenizer.encode('B', add_special_tokens=False)[0],
            tokenizer.encode('C', add_special_tokens=False)[0],
            tokenizer.encode('D', add_special_tokens=False)[0],
        ]

    @staticmethod
    def ABCD_to_0123(abcd):
        return {'A': 0, 'B': 1, 'C': 2, 'D': 3}[abcd]

    @staticmethod
    def find_first_zero_index(tensor):
        indices = torch.nonzero(tensor == 0)
        if indices.numel() > 0:
            return indices[0].item()
        else:
            return None

    @staticmethod
    def accuracy(preds, gts):
        """Computes the accuracy for preds and gts."""
        correct = [1 if pred == gt else 0 for pred, gt in zip(preds, gts)]
        acc = np.mean(correct) * 100
        return acc
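
    # Evaluation flow (inherited from mmengine's ``BaseMetric``): the runner
    # calls ``process`` once per dataloader batch to accumulate per-sample
    # results into ``self.results``, then ``evaluate`` gathers those results
    # (across ranks in distributed runs) and passes the full list to
    # ``compute_metrics``.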

    def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch (Any): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        subjects = data_batch['data_samples']['subjects']
        gts = [
            self.ABCD_to_0123(gt)
            for gt in data_batch['data_samples']['labels']
        ]
        preds = []
        for sample, attn_mask, subject, gt in zip(
                data_samples, data_batch['data']['attention_mask'], subjects,
                gts):
            pred_logits = sample['logits']
            # Score the logits at the last non-padded position: if the
            # attention mask has no zeros the full sequence is used and the
            # last step is taken; otherwise the step just before the first
            # padded one. The prediction is the argmax over the token ids of
            # 'A', 'B', 'C' and 'D'.
            first_zero_idx = self.find_first_zero_index(attn_mask)
            pred_idx = -1 if first_zero_idx is None else first_zero_idx - 1
            pred_logits_abcd = pred_logits[pred_idx, self.abcd_idx]
            pred = torch.argmax(pred_logits_abcd).item()
            preds.append(pred)
            self.results.append((subject, pred, gt))

    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        subjects_results = {
            subject: {
                'preds': [],
                'gts': []
            }
            for subject in self.METAINFO['subcategories'].keys()
        }
        subcats_results = {
            subcat: {
                'preds': [],
                'gts': []
            }
            for subcat in self.METAINFO['subcategories_list']
        }
        cats_results = {
            cat: {
                'preds': [],
                'gts': []
            }
            for cat in self.METAINFO['categories'].keys()
        }
        for subject, pred, gt in results:
            subjects_results[subject]['preds'].append(pred)
            subjects_results[subject]['gts'].append(gt)
            subcats = self.METAINFO['subcategories'][subject]
            for subcat in subcats:
                subcats_results[subcat]['preds'].append(pred)
                subcats_results[subcat]['gts'].append(gt)
        for cat, subcats in self.METAINFO['categories'].items():
            for subcat in subcats:
                if subcat in subcats_results:
                    cats_results[cat]['preds'].extend(
                        subcats_results[subcat]['preds'])
                    cats_results[cat]['gts'].extend(
                        subcats_results[subcat]['gts'])

        subjects_metrics = dict()
        subcats_metrics = dict()
        cats_metrics = dict()
        for subject in self.METAINFO['subcategories'].keys():
            assert len(subjects_results[subject]['preds']) == len(
                subjects_results[subject]['gts'])
            if len(subjects_results[subject]['preds']) == 0:
                print_log(f'Skip subject {subject} for mmlu', 'current')
            else:
                score = self.accuracy(subjects_results[subject]['preds'],
                                      subjects_results[subject]['gts'])
                subjects_metrics[f'{subject}'] = score
        for subcat in self.METAINFO['subcategories_list']:
            assert len(subcats_results[subcat]['preds']) == len(
                subcats_results[subcat]['gts'])
            if len(subcats_results[subcat]['preds']) == 0:
                print_log(f'Skip subcategory {subcat} for mmlu', 'current')
            else:
                score = self.accuracy(subcats_results[subcat]['preds'],
                                      subcats_results[subcat]['gts'])
                subcats_metrics[f'{subcat}'] = score
        for cat in self.METAINFO['categories'].keys():
            assert len(cats_results[cat]['preds']) == len(
                cats_results[cat]['gts'])
            if len(cats_results[cat]['preds']) == 0:
                print_log(f'Skip category {cat} for mmlu', 'current')
            else:
                score = self.accuracy(cats_results[cat]['preds'],
                                      cats_results[cat]['gts'])
                cats_metrics[f'{cat}'] = score

        metrics = dict()
        metrics.update(subjects_metrics)
        metrics.update(subcats_metrics)
        metrics.update(cats_metrics)
        metrics['average'] = np.mean(list(subjects_metrics.values()))

        table_metrics = dict()
        table_metrics.update(cats_metrics)
        table_metrics['average'] = np.mean(list(subjects_metrics.values()))
        self._print_results(table_metrics)

        return metrics
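
    # Note on the aggregation above: subcategory and category scores are
    # micro averages over the pooled per-question predictions, while the
    # reported 'average' is the unweighted (macro) mean of the per-subject
    # accuracies, so the two can diverge when subjects differ in size.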

    def _print_results(self, table_metrics: dict) -> None:
        table_title = ' MMLU Benchmark '
        table = Table(title=table_title)
        console = Console()
        table.add_column('Categories', justify='left')
        table.add_column('Accuracy (%)', justify='right')
        for cat, acc in table_metrics.items():
            table.add_row(cat, f'{acc:.1f}')
        with console.capture() as capture:
            console.print(table, end='')
        print_log('\n' + capture.get(), 'current')
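
# A minimal usage sketch (illustrative only, not part of the module). In an
# xtuner config the metric is usually declared with a tokenizer config that
# ``BUILDER.build`` can resolve; the model path and exact keys below are
# assumptions drawn from common xtuner configs, not guarantees of this file:
#
#     from transformers import AutoTokenizer
#
#     val_evaluator = dict(
#         type=MMLUMetric,
#         tokenizer=dict(
#             type=AutoTokenizer.from_pretrained,
#             pretrained_model_name_or_path='internlm/internlm-7b',
#             trust_remote_code=True))
#
# The runner then feeds each validation batch through ``process`` and logs
# the dict returned by ``compute_metrics``.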