import itertools
from collections import defaultdict
from typing import List, Optional, Sequence

import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import print_log
from rich.console import Console
from rich.table import Table


class RewardMetric(BaseMetric):
    r"""Reward model evaluation metric.

    Compares the reward score assigned to each chosen response with the
    score of its rejected counterpart, and reports per-dataset accuracy
    together with the mean chosen and rejected scores.

    Args:
        collect_device (str): Device used to collect results from different
            ranks, e.g. 'cpu' or 'gpu'. Defaults to 'cpu'.
        prefix (str, optional): Prefix added to the metric names to
            disambiguate homonymous metrics of different evaluators. If not
            given, ``self.default_prefix`` is used instead. Defaults to None.
    """

    default_prefix: Optional[str] = ''

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        # Gather the per-sample reward logits into a single tensor.
        logits = torch.cat(
            [sample['logits'].unsqueeze(0) for sample in data_samples], dim=0)
        labels = data_batch['data']['labels']
        ds_names = data_batch['data_samples']['ds_names']

        # Label 0 marks the chosen response, label 1 the rejected one.
        chosen_idx = torch.where(labels == 0)
        rejected_idx = torch.where(labels == 1)
        chosen_logits = logits[chosen_idx].cpu()
        rejected_logits = logits[rejected_idx].cpu()

        # A pair counts as correct when the chosen response scores higher
        # than its rejected counterpart.
        correct = (chosen_logits > rejected_logits).cpu()
        self.results.append({
            'chosen_logits': chosen_logits,
            'rejected_logits': rejected_logits,
            'correct': correct,
            'ds_names': ds_names
        })

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        metrics = {}

        # Concatenate the per-batch results collected in ``process``.
        correct = torch.cat([res['correct'] for res in results])
        chosen_logits = torch.cat([res['chosen_logits'] for res in results])
        rejected_logits = torch.cat(
            [res['rejected_logits'] for res in results])
        ds_names = list(itertools.chain(*[res['ds_names'] for res in results]))

        # Group the results by dataset name.
        grouped_correct = defaultdict(list)
        grouped_chosen_logits = defaultdict(list)
        grouped_rejected_logits = defaultdict(list)
        for i, ds_name in enumerate(ds_names):
            grouped_correct[ds_name].append(correct[i])
            grouped_chosen_logits[ds_name].append(chosen_logits[i])
            grouped_rejected_logits[ds_name].append(rejected_logits[i])

        # Render a per-dataset summary table alongside the returned metrics.
        table = Table(title='Reward Metrics')
        table.add_column('Dataset Name')
        table.add_column('Accuracy')
        table.add_column('Chosen Score')
        table.add_column('Rejected Score')

        for ds_name in grouped_correct.keys():
            correct = torch.stack(grouped_correct[ds_name])
            chosen_logits = torch.stack(grouped_chosen_logits[ds_name])
            rejected_logits = torch.stack(grouped_rejected_logits[ds_name])

            acc = correct.float().mean()
            metrics[f'accuracy/{ds_name}'] = acc.item()
            metrics[f'chosen_score/{ds_name}'] = chosen_logits.mean().item()
            metrics[f'rejected_score/{ds_name}'] = rejected_logits.mean().item()

            table.add_row(ds_name, f'{acc:.4f}',
                          f'{chosen_logits.mean():.4f}',
                          f'{rejected_logits.mean():.4f}')

        console = Console()
        with console.capture() as capture:
            console.print(table, end='')
        print_log('\n' + capture.get(), 'current')

        return metrics
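

if __name__ == '__main__':
    # Minimal usage sketch with toy data, under the input layout assumed by
    # ``process``: ``labels`` marks chosen responses with 0 and rejected ones
    # with 1, ``ds_names`` holds one dataset name per chosen/rejected pair,
    # and each data sample carries a scalar reward logit. The dataset name
    # 'demo' and the values below are illustrative only.
    metric = RewardMetric()
    fake_samples = [{'logits': torch.tensor(s)}
                    for s in (2.0, 1.0, 0.5, 1.5)]
    fake_batch = {
        'data': {'labels': torch.tensor([0, 1, 0, 1])},
        'data_samples': {'ds_names': ['demo', 'demo']},
    }
    metric.process(fake_batch, fake_samples)
    # Expected: accuracy/demo == 0.5, since only the first pair is ordered
    # correctly (2.0 > 1.0, but 0.5 < 1.5).
    print(metric.compute_metrics(metric.results))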