""" Dataset Factory

Hacked together by / Copyright 2021, Ross Wightman
"""
import os

from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
try:
    from torchvision.datasets import Places365
    has_places365 = True
except ImportError:
    has_places365 = False
try:
    from torchvision.datasets import INaturalist
    has_inaturalist = True
except ImportError:
    has_inaturalist = False

from .dataset import IterableImageDataset, ImageDataset

_TORCH_BASIC_DS = dict(
    cifar10=CIFAR10,
    cifar100=CIFAR100,
    mnist=MNIST,
    qmnist=QMNIST,
    kmnist=KMNIST,
    fashion_mnist=FashionMNIST,
)
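
# Split alias tables below: only the keys matter, membership checks such as
# `split in _TRAIN_SYNONYM` treat these dicts as ordered sets of accepted names.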
_TRAIN_SYNONYM = dict(train=None, training=None)
_EVAL_SYNONYM = dict(val=None, valid=None, validation=None, eval=None, evaluation=None)


def _search_split(root, split):
    # look for sub-folder with name of split in root and use that if it exists
    split_name = split.split('[')[0]
    try_root = os.path.join(root, split_name)
    if os.path.exists(try_root):
        return try_root

    def _try(syn):
        for s in syn:
            try_root = os.path.join(root, s)
            if os.path.exists(try_root):
                return try_root
        return root

    if split_name in _TRAIN_SYNONYM:
        root = _try(_TRAIN_SYNONYM)
    elif split_name in _EVAL_SYNONYM:
        root = _try(_EVAL_SYNONYM)
    return root
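
# Illustrative note (not part of the original file): given a hypothetical layout
# /data/imagenet/val/..., a call like `_search_split('/data/imagenet', 'validation')`
# first checks /data/imagenet/validation, then tries the _EVAL_SYNONYM aliases and
# returns /data/imagenet/val; if no matching sub-folder exists, root is returned as-is.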


def create_dataset(
        name,
        root,
        split='validation',
        search_split=True,
        class_map=None,
        load_bytes=False,
        is_training=False,
        download=False,
        batch_size=None,
        repeats=0,
        **kwargs
):
    """ Dataset factory method

    In parentheses after each arg is the type of dataset supported for that arg, one of:
      * folder - default, timm folder (or tar) based ImageDataset
      * torch - torchvision based datasets
      * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
      * all - any of the above

    Args:
        name: dataset name, empty is okay for folder based datasets
        root: root folder of dataset (all)
        split: dataset split (all)
        search_split: search for split specific child folder from root so one can specify
            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
        class_map: specify class -> index mapping via text file or dict (folder)
        load_bytes: load data, return images as undecoded bytes (folder)
        download: download dataset if not present and supported (TFDS, torch)
        is_training: create dataset in train mode, this is different from the split.
            For Iterable / TFDS datasets it enables shuffle, ignored for other datasets. (TFDS)
        batch_size: batch size hint for IterableImageDataset (TFDS)
        repeats: dataset repeats per iteration i.e. epoch (TFDS)
        **kwargs: other args to pass to dataset

    Returns:
        Dataset object
    """
    name = name.lower()
    if name.startswith('torch/'):
        name = name.split('/', 2)[-1]
        torch_kwargs = dict(root=root, download=download, **kwargs)
        if name in _TORCH_BASIC_DS:
            ds_class = _TORCH_BASIC_DS[name]
            use_train = split in _TRAIN_SYNONYM
            ds = ds_class(train=use_train, **torch_kwargs)
        elif name == 'inaturalist' or name == 'inat':
            assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist'
            target_type = 'full'
            split_split = split.split('/')
            if len(split_split) > 1:
                target_type = split_split[0].split('_')
                if len(target_type) == 1:
                    target_type = target_type[0]
                split = split_split[-1]
            if split in _TRAIN_SYNONYM:
                split = '2021_train'
            elif split in _EVAL_SYNONYM:
                split = '2021_valid'
            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
        elif name == 'places365':
            assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'
            if split in _TRAIN_SYNONYM:
                split = 'train-standard'
            elif split in _EVAL_SYNONYM:
                split = 'val'
            ds = Places365(split=split, **torch_kwargs)
        elif name == 'imagenet':
            if split in _EVAL_SYNONYM:
                split = 'val'
            ds = ImageNet(split=split, **torch_kwargs)
        elif name == 'image_folder' or name == 'folder':
            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
            if search_split and os.path.isdir(root):
                # look for split specific sub-folder in root
                root = _search_split(root, split)
            ds = ImageFolder(root, **kwargs)
        else:
            assert False, f"Unknown torchvision dataset {name}"
    elif name.startswith('tfds/'):
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training,
            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
    else:
        # FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
        if search_split and os.path.isdir(root):
            # look for split specific sub-folder in root
            root = _search_split(root, split)
        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
    return ds
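

# Usage sketch (illustrative, not part of the original file). Callers normally reach this
# factory through the package, e.g. `from timm.data import create_dataset`; the paths
# './data' and '/data/imagenet' below are hypothetical placeholders.
#
#     # torchvision CIFAR-10, train split; 'train' resolves via _TRAIN_SYNONYM
#     train_ds = create_dataset('torch/cifar10', root='./data', split='train', download=True)
#
#     # default folder/tar-based ImageDataset; with search_split=True (the default), a
#     # 'val'/'validation' sub-folder under root is located via _search_split
#     eval_ds = create_dataset('', root='/data/imagenet', split='validation')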