from typing import List, Optional

from transformers import PretrainedConfig
class ResnetConfig(PretrainedConfig):
    """Configuration object for a ResNet-style model (``model_type = "resnet"``).

    Args:
        block_type: Residual block variant; must be ``"basic"`` or
            ``"bottleneck"``. Defaults to ``"bottleneck"``.
        layers: Per-stage layer counts, stored on ``self.layers``
            (presumably the number of residual blocks in each stage; it is
            interpreted by the model class, not here). Defaults to
            ``[3, 4, 6, 3]``. A new list is stored on every instance, so
            configs never share mutable state.
        num_classes: Stored on ``self.num_classes`` (presumably the size of
            the classification head). Defaults to 1000.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``block_type`` is not ``"basic"`` or ``"bottleneck"``.
    """

    model_type = "resnet"

    def __init__(
        self,
        block_type: str = "bottleneck",
        layers: Optional[List[int]] = None,
        num_classes: int = 1000,
        **kwargs,
    ):
        # Validate before touching any state so a bad value leaves nothing
        # half-initialized.
        if block_type not in ("basic", "bottleneck"):
            raise ValueError(
                f"`block_type` must be 'basic' or 'bottleneck', got {block_type}."
            )
        self.block_type = block_type
        # A literal list default would be a single object shared by every
        # instance; mutating one config's ``layers`` would silently change the
        # default for all later configs. Build a fresh list per instance and
        # copy caller-supplied sequences for the same reason.
        self.layers = [3, 4, 6, 3] if layers is None else list(layers)
        self.num_classes = num_classes
        super().__init__(**kwargs)