# Copyright (c) OpenMMLab. All rights reserved.
from torch.utils.data import ConcatDataset as TorchConcatDataset

from xtuner.registry import BUILDER


class ConcatDataset(TorchConcatDataset):
    """Concatenation of datasets that are built from configs.

    Each entry of ``datasets`` is passed to ``BUILDER.build``. The resulting
    objects are then handed, in iteration order, to
    :class:`torch.utils.data.ConcatDataset`.

    Args:
        datasets: Iterable of configs, one per sub-dataset. Every entry must
            be accepted by ``BUILDER.build``.
    """

    def __init__(self, datasets):
        built = [BUILDER.build(dataset_cfg) for dataset_cfg in datasets]
        super().__init__(datasets=built)

    def __repr__(self):
        # A fixed header line, followed by the repr of every sub-dataset
        # (``self.datasets`` is populated by the torch base class),
        # separated by ",\n".
        header = 'Dataset as a concatenation of multiple datasets. \n'
        body = ',\n'.join(map(repr, self.datasets))
        return header + body