from __future__ import annotations

import json
import logging
import os

import h5py
import numpy as np
import pandas as pd
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from torch.utils.data import Dataset, get_worker_info

from xtuner.registry import BUILDER

from .huggingface import process_hf_dataset

class LLaVADataset(Dataset):
    """LLaVA-style dataset pairing tokenized conversations with pre-extracted
    whole-slide patch features stored as .pt, .csv or .h5 files.

    Text is either loaded from ``offline_processed_text_folder`` or built from
    ``data_path`` via ``process_hf_dataset``. At ``__getitem__`` time, up to
    ``sample_num`` patch features per image are drawn according to
    ``sample_strategy`` ('linspace', 'random' or 'random_full').
    """

    def __init__(self,
                 image_folder,
                 image_path_list,
                 per_image_length,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 sample_num=10240,
                 image_feature_prefix='',
                 identifier='',
                 image_feature_suffix='.pt',
                 unwanted_prefix_csv=None,
                 sample_strategy='linspace',
                 # ---------- DEBUG ----------
                 debug_max_samples=None,
                 debug_ratio=None,
                 debug_shuffle=True,
                 debug_seed=3407,
                 debug_include_ids=None):
        super().__init__()
        if sample_strategy not in ('linspace', 'random', 'random_full'):
            raise ValueError(f'Unsupported sample_strategy: {sample_strategy}')
        self.sample_num = int(sample_num)
        self.per_image_length = int(per_image_length)
        self.pad_image_to_square = pad_image_to_square
        self.image_feature_prefix = image_feature_prefix
        self.identifier = identifier
        self.sample_strategy = sample_strategy  # 'linspace' | 'random' | 'random_full'
        # debug opts
        self._dbg_max = debug_max_samples
        self._dbg_ratio = debug_ratio
        self._dbg_shuffle = debug_shuffle
        self._dbg_seed = int(debug_seed)
        self._dbg_include_ids = set(debug_include_ids) if debug_include_ids else None

        assert offline_processed_text_folder or (data_path and tokenizer)
        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and `data_path` are set, '
                'and we load dataset from `offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current', level=logging.WARNING)

        # ---------------------- load text ----------------------
        if offline_processed_text_folder is not None:
            ds = load_from_disk(offline_processed_text_folder)
            if isinstance(ds, DatasetDict):
                ds = ds.get('train', None) or next(iter(ds.values()))
            assert isinstance(ds, HFDataset)
            text_ds = self._apply_debug_subset_to_hf(ds)
            self.text_data = text_ds
        else:
            if data_path.endswith('.json'):
                with open(data_path) as f:
                    json_data = json.load(f)
            elif data_path.endswith('.jsonl'):
                json_data = self._load_jsonl(data_path)
            else:
                raise NotImplementedError
            # ---- filter out samples whose images match unwanted prefixes
            unwanted_prefixes = self._load_unwanted_prefixes(unwanted_prefix_csv)
            original_count = len(json_data)
            filtered = []
            for item in json_data:
                imgs = item.get('image', [])
                if isinstance(imgs, str):
                    imgs = [imgs]
                keep = True
                for img in imgs:
                    if any(pref in img for pref in unwanted_prefixes):
                        keep = False
                        break
                if keep:
                    filtered.append(item)
            json_data = filtered
            print_log(f'Filtered out {original_count - len(json_data)} samples.', logger='current')
            # ---- debug: keep only explicitly requested ids
            if self._dbg_include_ids:
                keep = [it for it in json_data if str(it.get('id')) in self._dbg_include_ids]
                print_log(f'[DEBUG] include_ids -> keep {len(keep)}/{len(json_data)}', logger='current')
                json_data = keep
            # ---- debug: subsample
            json_data = self._apply_debug_subset_to_list(json_data)
            # id -> str
            for idx in range(len(json_data)):
                if isinstance(json_data[idx].get('id'), int):
                    json_data[idx]['id'] = str(json_data[idx]['id'])
            # HF map & template
            json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
            self.text_data = process_hf_dataset(
                dataset=json_data,
                tokenizer=tokenizer,
                max_length=max_length,
                dataset_map_fn=dataset_map_fn,
                template_map_fn=template_map_fn,
                split='train',
                max_dataset_length=max_dataset_length,
                remove_unused_columns=False,
                pack_to_max_length=False,
                with_image_token=True,
                per_image_length=self.per_image_length)
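            # After mapping, each row of `text_data` is expected to carry
            # tokenized 'input_ids' plus the original 'image' path(s);
            # `modality_length` and `__getitem__` below rely on these fields.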

        # ---------------------- image feature suffix sanity ----------------------
        if image_feature_suffix not in ['.csv', '.pt', '.h5']:
            raise ValueError(
                f'Unsupported image feature suffix: {image_feature_suffix}. '
                'Supported suffixes are: .csv, .pt, .h5')
        self.image_feature_suffix = image_feature_suffix
        self.image_folder = image_folder
        self.image_path_list = image_path_list

    # ---------------------- helpers ----------------------
    def _load_unwanted_prefixes(self, csv_path):
        unwanted_prefixes = set()
        if csv_path and os.path.exists(csv_path):
            print_log(f'Loading unwanted prefixes from: {csv_path}', logger='current')
            try:
                df = pd.read_csv(csv_path)
                unwanted_prefixes = set(df.iloc[:, 0].astype(str).tolist())
                print_log(f'Loaded {len(unwanted_prefixes)} prefixes to filter out.', logger='current')
            except Exception as e:
                print_log(f'Could not read CSV file {csv_path}. Error: {e}',
                          logger='current', level=logging.ERROR)
                print_log('Falling back to hardcoded list.', logger='current', level=logging.WARNING)
        if not unwanted_prefixes:
            print_log('Using hardcoded unwanted prefix list.', logger='current', level=logging.WARNING)
            unwanted_prefixes = {
                "TCGA-HT-7476-01Z-00-DX2", "TCGA-44-7661-01Z-00-DX1", "TCGA-DB-A64V-01Z-00-DX1",
                "TCGA-CS-4938-01Z-00-DX1", "TCGA-DB-5273-01Z-00-DX2", "TCGA-DB-5278-01Z-00-DX1",
                "TCGA-DB-A4XA-01Z-00-DX1", "TCGA-DB-A4XB-01Z-00-DX1", "TCGA-DB-A4XC-01Z-00-DX2",
                "TCGA-DU-5849-01Z-00-DX1", "TCGA-DU-6399-01Z-00-DX1", "TCGA-DU-7006-01Z-00-DX1",
                "TCGA-DU-7013-01Z-00-DX1", "TCGA-DU-8165-01Z-00-DX1", "TCGA-DU-A76O-01Z-00-DX1",
                "TCGA-DU-A7TG-01Z-00-DX1", "TCGA-E1-A7YM-01Z-00-DX1", "TCGA-E1-A7Z6-01Z-00-DX1",
                "TCGA-FG-A6J3-01Z-00-DX2", "TCGA-HT-7467-01Z-00-DX2", "TCGA-HT-7468-01Z-00-DX6",
                "TCGA-HT-7470-01Z-00-DX4", "TCGA-HT-7470-01Z-00-DX9", "TCGA-HT-7473-01Z-00-DX2",
                "TCGA-HT-7475-01Z-00-DX5", "TCGA-HT-7481-01Z-00-DX1", "TCGA-HT-7482-01Z-00-DX6",
                "TCGA-HT-7601-01Z-00-DX3", "TCGA-HT-7607-01Z-00-DX10", "TCGA-HT-7608-01Z-00-DX2",
                "TCGA-HT-7616-01Z-00-DX1", "TCGA-HT-7684-01Z-00-DX2", "TCGA-HT-7689-01Z-00-DX1",
                "TCGA-HT-7690-01Z-00-DX4", "TCGA-HT-7855-01Z-00-DX1", "TCGA-HT-7856-01Z-00-DX6",
                "TCGA-HT-7874-01Z-00-DX2", "TCGA-HT-8105-01Z-00-DX1", "TCGA-HT-8108-01Z-00-DX1",
                "TCGA-HT-A74O-01Z-00-DX1", "TCGA-IK-8125-01Z-00-DX1", "TCGA-P5-A72X-01Z-00-DX1",
                "TCGA-QH-A65R-01Z-00-DX1", "TCGA-QH-A870-01Z-00-DX1", "TCGA-R8-A6MO-01Z-00-DX7",
                "TCGA-S9-A6TX-01Z-00-DX1", "TCGA-TM-A84I-01Z-00-DX1", "TCGA-TM-A84L-01Z-00-DX1",
                "TCGA-TM-A84O-01Z-00-DX1", "TCGA-TQ-A7RP-01Z-00-DX1", "TCGA-VM-A8C8-01Z-00-DX8",
                "TCGA-VM-A8C9-01Z-00-DX9", "TCGA-VM-A8CA-01Z-00-DX4", "TCGA-VM-A8CB-01Z-00-DX4",
                "TCGA-VM-A8CB-01Z-00-DX5", "TCGA-VM-A8CD-01Z-00-DX6", "TCGA-VM-A8CE-01Z-00-DX1",
                "TCGA-VM-A8CE-01Z-00-DX7", "TCGA-QK-A8ZB-01Z-00-DX1"
            }
        return unwanted_prefixes

    def _load_jsonl(self, json_file):
        with open(json_file) as f:
            return [json.loads(line) for line in f]

    def _apply_debug_subset_to_list(self, items):
        if not items:
            return items
        n_before = len(items)
        if self._dbg_include_ids:
            items = [it for it in items if str(it.get('id')) in self._dbg_include_ids]
            n_before = len(items)
            print_log(f'[DEBUG] include_ids -> keep {n_before}', logger='current')
        if self._dbg_max is None and self._dbg_ratio is not None:
            self._dbg_max = max(1, int(round(n_before * float(self._dbg_ratio))))
        if self._dbg_max is None:
            print_log('[DEBUG] dataset full size used.', logger='current')
            return items
        k = min(int(self._dbg_max), n_before)
        if k <= 0:
            return items
        if self._dbg_shuffle:
            rng = np.random.default_rng(self._dbg_seed)
            idx = rng.choice(n_before, size=k, replace=False)
            idx = sorted(idx.tolist())
            items = [items[i] for i in idx]
        else:
            items = items[:k]
        print_log(f'[DEBUG] subset: {len(items)}/{n_before} samples used '
                  f'({"random" if self._dbg_shuffle else "head"}).',
                  logger='current')
        return items

    def _apply_debug_subset_to_hf(self, ds: HFDataset) -> HFDataset:
        n_before = ds.num_rows
        if self._dbg_include_ids:
            keep_idx = [i for i, ex in enumerate(ds) if str(ex.get('id')) in self._dbg_include_ids]
            ds = ds.select(keep_idx)
            print_log(f'[DEBUG] include_ids -> keep {ds.num_rows}/{n_before}', logger='current')
            n_before = ds.num_rows
        if self._dbg_max is None and self._dbg_ratio is not None:
            self._dbg_max = max(1, int(round(n_before * float(self._dbg_ratio))))
        if self._dbg_max is None:
            print_log('[DEBUG] dataset full size used (offline).', logger='current')
            return ds
        k = min(int(self._dbg_max), n_before)
        if k <= 0:
            return ds
        if self._dbg_shuffle:
            rng = np.random.default_rng(self._dbg_seed)
            idx = rng.choice(n_before, size=k, replace=False)
            idx = sorted(idx.tolist())
        else:
            idx = list(range(k))
        ds = ds.select(idx)
        print_log(f'[DEBUG] subset (offline): {ds.num_rows}/{n_before} samples used '
                  f'({"random" if self._dbg_shuffle else "head"}).',
                  logger='current')
        return ds

    # -------- per-worker RNG --------
    def _rng(self):
        """Return a numpy Generator seeded per-worker for reproducibility."""
        wi = get_worker_info()
        base = self._dbg_seed
        if wi is None:
            seed = (base ^ (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
        else:
            seed = (base + wi.id + (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
        return np.random.default_rng(seed)

    # -------- path parsing --------
    def _parse_stub(self, image_path: str):
        norm = os.path.normpath(image_path)
        parts = norm.split(os.sep)
        if len(parts) < 2:
            fname = os.path.splitext(parts[-1])[0]
            tumor = fname.split('-')[0].lower() if '-' in fname else 'unknown'
            case = fname
        else:
            tumor = parts[-2].lower()
            case = os.path.splitext(parts[-1])[0]
        return tumor, case
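    # Illustrative mapping (hypothetical paths): 'brca/TCGA-AB-1234.svs'
    # -> ('brca', 'TCGA-AB-1234'); a bare filename with no parent directory
    # falls back to the text before the first '-' (or 'unknown') as the tumor.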

    # -------- feature path construction --------
    def _build_feature_path(self, tumor_name: str, case_name: str):
        if self.image_feature_suffix == ".pt":
            subdir = "pt_files"
        elif self.image_feature_suffix == ".csv":
            subdir = "csv_files"
        elif self.image_feature_suffix == ".h5":
            subdir = "h5_files"
        else:
            raise ValueError(f"Unknown feature suffix: {self.image_feature_suffix}")
        return os.path.join(
            self.image_feature_prefix,
            f"{tumor_name}{self.identifier}",
            subdir,
            case_name + self.image_feature_suffix
        )
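    # Expected on-disk layout (illustrative; `identifier` defaults to an
    # empty string):
    #   <image_feature_prefix>/<tumor><identifier>/pt_files/<case>.pt
    #   <image_feature_prefix>/<tumor><identifier>/h5_files/<case>.h5
    #   <image_feature_prefix>/<tumor><identifier>/csv_files/<case>.csv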

    # -------- choose patch indices --------
    def _choose_indices(self, total_rows: int, rng: np.random.Generator):
        k = self.sample_num
        if total_rows <= 0:
            return np.array([], dtype=int)
        if self.sample_strategy == "random_full":
            # Always exactly k rows; with replacement if needed
            replace = total_rows < k
            idx = rng.choice(total_rows, size=k, replace=replace)
            return np.sort(idx.astype(int))
        if self.sample_strategy == "random":
            if total_rows <= k:
                return np.arange(total_rows, dtype=int)
            idx = rng.choice(total_rows, size=k, replace=False)
            return np.sort(idx.astype(int))
        # linspace: evenly spaced indices with a random offset
        if total_rows <= k:
            return np.arange(total_rows, dtype=int)
        step = total_rows / k
        jitter = int(rng.integers(0, max(1, int(step))))
        indices = (np.floor(np.arange(k) * step + jitter)).astype(int)
        return np.clip(indices, 0, total_rows - 1)
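    # Sampling behaviour, for reference (k = sample_num):
    #   linspace    -> roughly evenly spaced indices with a random start
    #                  offset; all rows when total_rows <= k.
    #   random      -> k indices without replacement; all rows (possibly
    #                  fewer than k) when total_rows <= k.
    #   random_full -> always exactly k indices, sampling with replacement
    #                  when the slide has fewer than k patches.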

    # ---------------------- rest of class ----------------------
    @property
    def modality_length(self):
        # Negative length marks a text-only sample; image samples account for
        # `per_image_length` tokens per image (minus one placeholder each).
        length_list = []
        for data_dict in self.text_data:
            cur_len = len(data_dict['input_ids'])
            image = data_dict.get('image', None)
            if image is None:
                cur_len = -cur_len
            else:
                n_images = 1 if isinstance(image, str) else len(image)
                cur_len = cur_len - n_images + self.per_image_length * n_images
            length_list.append(cur_len)
        return length_list

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, index):
        data_dict = self.text_data[index]
        if data_dict.get('image', None) is None:
            return data_dict
        image_list = data_dict['image']
        if isinstance(image_list, str):
            image_list = [image_list]
        images, coords_list = [], []
        rng = self._rng()
        for image_file in image_list:
            tumor_name, case_name = self._parse_stub(image_file)
            train_image_file = self._build_feature_path(tumor_name, case_name)
            if train_image_file.endswith('.csv'):
                if not os.path.exists(train_image_file):
                    raise FileNotFoundError(train_image_file)
                feats_df = pd.read_csv(train_image_file, usecols=range(512), dtype=np.float32)
                total_rows = len(feats_df)
                idx = self._choose_indices(total_rows, rng)
                feats = torch.from_numpy(feats_df.to_numpy()[idx]).float()
                images.append(feats)
                coords_list.append(None)
            elif train_image_file.endswith('.pt'):
                if not os.path.exists(train_image_file):
                    raise FileNotFoundError(train_image_file)
                feats_np = torch.load(train_image_file, map_location='cpu')
                if isinstance(feats_np, torch.Tensor):
                    feats_np = feats_np.cpu().numpy()
                feats_np = feats_np.astype(np.float32, copy=False)
                total_rows = feats_np.shape[0]
                idx = self._choose_indices(total_rows, rng)
                feats = torch.from_numpy(feats_np[idx]).float()
                images.append(feats)
                coords_list.append(None)
            elif train_image_file.endswith('.h5'):
                if not os.path.exists(train_image_file):
                    raise FileNotFoundError(train_image_file)
                with h5py.File(train_image_file, 'r') as f:
                    feats_np = f['features'][:]
                    coords_np = f['coords'][:]
                if feats_np.shape[0] != coords_np.shape[0]:
                    raise ValueError(
                        f"Row count mismatch between features ({feats_np.shape[0]}) "
                        f"and coords ({coords_np.shape[0]}) for {train_image_file}")
                feats_np = feats_np.astype(np.float32, copy=False)
                total_rows = feats_np.shape[0]
                idx = self._choose_indices(total_rows, rng)
                feats = torch.from_numpy(feats_np[idx]).float()
                coords = torch.from_numpy(coords_np[idx]).long()
                images.append(feats)
                coords_list.append(coords)
            else:
                raise ValueError(f'Unsupported file: {train_image_file}')
        data_dict['pixel_values'] = images
        if any(c is not None for c in coords_list):
            coords_list = [c if c is not None else torch.empty(0, 2, dtype=torch.long)
                           for c in coords_list]
            data_dict['coords'] = coords_list
        return data_dict
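

# ---------------------------------------------------------------------------
# Minimal sketch of the patch-sampling behaviour in isolation (illustrative
# only, not part of the training pipeline). It bypasses __init__ because a
# real instantiation needs a tokenizer, map functions and feature files on
# disk; only the attributes `_choose_indices` reads are set here. Run as a
# module inside its package (e.g. `python -m <package>.<module>`).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    demo = object.__new__(LLaVADataset)  # skip __init__ on purpose
    demo.sample_num = 8
    rng = np.random.default_rng(0)
    for strategy in ('linspace', 'random', 'random_full'):
        demo.sample_strategy = strategy
        # A slide with only 5 patches: linspace/random return all 5 rows,
        # random_full pads to exactly 8 indices by sampling with replacement.
        print(strategy, demo._choose_indices(total_rows=5, rng=rng).tolist())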