# import json
# import logging
# import os
# import numpy as np
# import pandas as pd
# import torch
# import h5py
# from datasets import Dataset as HFDataset
# from datasets import DatasetDict, load_from_disk
# from mmengine import print_log
# from torch.utils.data import Dataset, get_worker_info
# from xtuner.registry import BUILDER
# from .huggingface import process_hf_dataset
# # Map sampling strategies to shared integers to ease multi-process synchronization
# _STRATEGY2ID = {"linspace": 0, "random": 1, "random_full": 2}
# _ID2STRATEGY = {v: k for k, v in _STRATEGY2ID.items()}
# class LLaVADataset(Dataset):
# def __init__(self,
# image_folder,
# image_path_list,
# per_image_length,
# data_path=None,
# tokenizer=None,
# offline_processed_text_folder=None,
# max_dataset_length=None,
# dataset_map_fn=None,
# template_map_fn=None,
# max_length=2048,
# pad_image_to_square=False,
# sample_num=10240,
# image_feature_prefix='',
# identifier='',
# image_feature_suffix='.pt',
# unwanted_prefix_csv=None,
# sample_strategy: str = 'linspace', # new: evenly spaced by default
# # ---------- DEBUG options ----------
# debug_max_samples=None,
# debug_ratio=None,
# debug_shuffle=True,
# debug_seed=3407,
# debug_include_ids=None):
# super().__init__()
# # ---- Expose mutable control values via shared memory so that changes made by a Hook in the main process are visible to workers ----
# self._sample_num_shm = torch.tensor([int(sample_num)], dtype=torch.int32)
# self._sample_num_shm.share_memory_()
# self._pil_shm = torch.tensor([int(per_image_length)], dtype=torch.int32)
# self._pil_shm.share_memory_()
# if sample_strategy not in _STRATEGY2ID:
# raise ValueError(f"Unsupported sample_strategy: {sample_strategy}")
# self._strategy_shm = torch.tensor([_STRATEGY2ID[sample_strategy]], dtype=torch.int32)
# self._strategy_shm.share_memory_()
# self.pad_image_to_square = pad_image_to_square
# self.image_feature_prefix = image_feature_prefix
# self.identifier = identifier
# # debug opts
# self._dbg_max = debug_max_samples
# self._dbg_ratio = debug_ratio
# self._dbg_shuffle = debug_shuffle
# self._dbg_seed = int(debug_seed)
# self._dbg_include_ids = set(debug_include_ids) if debug_include_ids else None
# assert offline_processed_text_folder or (data_path and tokenizer)
# if offline_processed_text_folder and data_path:
# print_log(
# 'Both `offline_processed_text_folder` and `data_path` are set, '
# 'and we load dataset from `offline_processed_text_folder` '
# f'({offline_processed_text_folder})',
# logger='current', level=logging.WARNING)
# # ---------------------- load text ----------------------
# if offline_processed_text_folder is not None:
# ds = load_from_disk(offline_processed_text_folder)
# if isinstance(ds, DatasetDict):
# ds = ds.get('train', None) or next(iter(ds.values()))
# assert isinstance(ds, HFDataset)
# text_ds = ds
# text_ds = self._apply_debug_subset_to_hf(text_ds)
# self.text_data = text_ds
# else:
# if data_path.endswith('.json'):
# json_data = json.load(open(data_path))
# elif data_path.endswith('.jsonl'):
# json_data = self._load_jsonl(data_path)
# else:
# raise NotImplementedError
# # ---- filter out unwanted prefixes (accepts both a string and a list)
# unwanted_prefixes = self._load_unwanted_prefixes(unwanted_prefix_csv)
# original_count = len(json_data)
# filtered = []
# for item in json_data:
# imgs = item.get('image', [])
# if isinstance(imgs, str):
# imgs = [imgs]
# keep = True
# for img in imgs:
# if any(pref in img for pref in unwanted_prefixes):
# keep = False
# break
# if keep:
# filtered.append(item)
# json_data = filtered
# print_log(f'Filtered out {original_count - len(json_data)} samples.', logger='current')
# # ---- debug: filter by include_ids first
# if self._dbg_include_ids:
# keep = [it for it in json_data if str(it.get('id')) in self._dbg_include_ids]
# print_log(f'[DEBUG] include_ids -> keep {len(keep)}/{len(json_data)}', logger='current')
# json_data = keep
# # ---- debug: subset sampling
# json_data = self._apply_debug_subset_to_list(json_data)
# # id -> str
# for idx in range(len(json_data)):
# if isinstance(json_data[idx].get('id'), int):
# json_data[idx]['id'] = str(json_data[idx]['id'])
# # HF map & template
# json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
# self.text_data = process_hf_dataset(
# dataset=json_data,
# tokenizer=tokenizer,
# max_length=max_length,
# dataset_map_fn=dataset_map_fn,
# template_map_fn=template_map_fn,
# split='train',
# max_dataset_length=max_dataset_length,
# remove_unused_columns=False,
# pack_to_max_length=False,
# with_image_token=True,
# per_image_length=self.per_image_length)
# # ---------------------- image feature suffix sanity ----------------------
# if image_feature_suffix not in ['.csv', '.pt', '.h5']:
# raise ValueError(
# f'Unsupported image feature suffix: {image_feature_suffix}. '
# 'Supported suffixes are: .csv, .pt, .h5')
# self.image_feature_suffix = image_feature_suffix
# self.image_folder = image_folder
# self.image_path_list = image_path_list
# # ---------------------- shared-backed properties ----------------------
# @property
# def sample_num(self) -> int:
# return int(self._sample_num_shm.item())
# @sample_num.setter
# def sample_num(self, v: int):
# self._sample_num_shm.fill_(int(v))
# @property
# def per_image_length(self) -> int:
# return int(self._pil_shm.item())
# @per_image_length.setter
# def per_image_length(self, v: int):
# self._pil_shm.fill_(int(v))
# @property
# def sample_strategy(self) -> str:
# return _ID2STRATEGY[int(self._strategy_shm.item())]
# @sample_strategy.setter
# def sample_strategy(self, v: str):
# if v not in _STRATEGY2ID:
# raise ValueError(f"Unknown sample_strategy: {v}")
# self._strategy_shm.fill_(_STRATEGY2ID[v])
# # ---------------------- helpers ----------------------
# def _load_unwanted_prefixes(self, csv_path):
# unwanted_prefixes = set()
# if csv_path and os.path.exists(csv_path):
# print_log(f'Loading unwanted prefixes from: {csv_path}', logger='current')
# try:
# df = pd.read_csv(csv_path)
# unwanted_prefixes = set(df.iloc[:, 0].astype(str).tolist())
# print_log(f'Loaded {len(unwanted_prefixes)} prefixes to filter out.', logger='current')
# except Exception as e:
# print_log(f'Could not read CSV file {csv_path}. Error: {e}',
# logger='current', level=logging.ERROR)
# print_log('Falling back to hardcoded list.', logger='current', level=logging.WARNING)
# if not unwanted_prefixes:
# print_log('Using hardcoded unwanted prefix list.', logger='current', level=logging.WARNING)
# unwanted_prefixes = {
# "TCGA-HT-7476-01Z-00-DX2", "TCGA-44-7661-01Z-00-DX1", "TCGA-DB-A64V-01Z-00-DX1",
# "TCGA-CS-4938-01Z-00-DX1", "TCGA-DB-5273-01Z-00-DX2", "TCGA-DB-5278-01Z-00-DX1",
# "TCGA-DB-A4XA-01Z-00-DX1", "TCGA-DB-A4XB-01Z-00-DX1", "TCGA-DB-A4XC-01Z-00-DX2",
# "TCGA-DU-5849-01Z-00-DX1", "TCGA-DU-6399-01Z-00-DX1", "TCGA-DU-7006-01Z-00-DX1",
# "TCGA-DU-7013-01Z-00-DX1", "TCGA-DU-8165-01Z-00-DX1", "TCGA-DU-A76O-01Z-00-DX1",
# "TCGA-DU-A7TG-01Z-00-DX1", "TCGA-E1-A7YM-01Z-00-DX1", "TCGA-E1-A7Z6-01Z-00-DX1",
# "TCGA-FG-A6J3-01Z-00-DX2", "TCGA-HT-7467-01Z-00-DX2", "TCGA-HT-7468-01Z-00-DX6",
# "TCGA-HT-7470-01Z-00-DX4", "TCGA-HT-7470-01Z-00-DX9", "TCGA-HT-7473-01Z-00-DX2",
# "TCGA-HT-7475-01Z-00-DX5", "TCGA-HT-7481-01Z-00-DX1", "TCGA-HT-7482-01Z-00-DX6",
# "TCGA-HT-7601-01Z-00-DX3", "TCGA-HT-7607-01Z-00-DX10", "TCGA-HT-7608-01Z-00-DX2",
# "TCGA-HT-7616-01Z-00-DX1", "TCGA-HT-7684-01Z-00-DX2", "TCGA-HT-7689-01Z-00-DX1",
# "TCGA-HT-7690-01Z-00-DX4", "TCGA-HT-7855-01Z-00-DX1", "TCGA-HT-7856-01Z-00-DX6",
# "TCGA-HT-7874-01Z-00-DX2", "TCGA-HT-8105-01Z-00-DX1", "TCGA-HT-8108-01Z-00-DX1",
# "TCGA-HT-A74O-01Z-00-DX1", "TCGA-IK-8125-01Z-00-DX1", "TCGA-P5-A72X-01Z-00-DX1",
# "TCGA-QH-A65R-01Z-00-DX1", "TCGA-QH-A870-01Z-00-DX1", "TCGA-R8-A6MO-01Z-00-DX7",
# "TCGA-S9-A6TX-01Z-00-DX1", "TCGA-TM-A84I-01Z-00-DX1", "TCGA-TM-A84L-01Z-00-DX1",
# "TCGA-TM-A84O-01Z-00-DX1", "TCGA-TQ-A7RP-01Z-00-DX1", "TCGA-VM-A8C8-01Z-00-DX8",
# "TCGA-VM-A8C9-01Z-00-DX9", "TCGA-VM-A8CA-01Z-00-DX4", "TCGA-VM-A8CB-01Z-00-DX4",
# "TCGA-VM-A8CB-01Z-00-DX5", "TCGA-VM-A8CD-01Z-00-DX6", "TCGA-VM-A8CE-01Z-00-DX1",
# "TCGA-VM-A8CE-01Z-00-DX7", "TCGA-QK-A8ZB-01Z-00-DX1"
# }
# return unwanted_prefixes
# def _load_jsonl(self, json_file):
# with open(json_file) as f:
# return [json.loads(line) for line in f]
# def _apply_debug_subset_to_list(self, items):
# if not items:
# return items
# n_before = len(items)
# if self._dbg_include_ids:
# items = [it for it in items if str(it.get('id')) in self._dbg_include_ids]
# n_before = len(items)
# print_log(f'[DEBUG] include_ids -> keep {n_before}', logger='current')
# if self._dbg_max is None and self._dbg_ratio is not None:
# self._dbg_max = max(1, int(round(n_before * float(self._dbg_ratio))))
# if self._dbg_max is None:
# print_log('[DEBUG] dataset full size used.', logger='current')
# return items
# k = min(int(self._dbg_max), n_before)
# if k <= 0:
# return items
# if self._dbg_shuffle:
# rng = np.random.default_rng(self._dbg_seed)
# idx = rng.choice(n_before, size=k, replace=False)
# idx = sorted(idx.tolist())
# items = [items[i] for i in idx]
# else:
# items = items[:k]
# print_log(f'[DEBUG] subset: {len(items)}/{n_before} samples used '
# f'({"random" if self._dbg_shuffle else "head"}).',
# logger='current')
# return items
# def _apply_debug_subset_to_hf(self, ds: HFDataset) -> HFDataset:
# n_before = ds.num_rows
# if self._dbg_include_ids:
# keep_idx = [i for i, ex in enumerate(ds) if str(ex.get('id')) in self._dbg_include_ids]
# ds = ds.select(keep_idx)
# print_log(f'[DEBUG] include_ids -> keep {ds.num_rows}/{n_before}', logger='current')
# n_before = ds.num_rows
# if self._dbg_max is None and self._dbg_ratio is not None:
# self._dbg_max = max(1, int(round(n_before * float(self._dbg_ratio))))
# if self._dbg_max is None:
# print_log('[DEBUG] dataset full size used (offline).', logger='current')
# return ds
# k = min(int(self._dbg_max), n_before)
# if k <= 0:
# return ds
# if self._dbg_shuffle:
# rng = np.random.default_rng(self._dbg_seed)
# idx = rng.choice(n_before, size=k, replace=False)
# idx = sorted(idx.tolist())
# else:
# idx = list(range(k))
# ds = ds.select(idx)
# print_log(f'[DEBUG] subset (offline): {ds.num_rows}/{n_before} samples used '
# f'({"random" if self._dbg_shuffle else "head"}).',
# logger='current')
# return ds
# # -------- per-worker RNG for reproducibility --------
# def _rng(self):
# """Return a numpy Generator seeded per-worker for reproducibility."""
# wi = get_worker_info()
# base = self._dbg_seed
# if wi is None:
# seed = (base ^ (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
# else:
# seed = (base + wi.id + (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
# return np.random.default_rng(seed)
# # -------- path parsing --------
# def _parse_stub(self, image_path: str):
# norm = os.path.normpath(image_path)
# parts = norm.split(os.sep)
# if len(parts) < 2:
# fname = os.path.splitext(parts[-1])[0]
# tumor = fname.split('-')[0].lower() if '-' in fname else 'unknown'
# case = fname
# else:
# tumor = parts[-2].lower()
# case = os.path.splitext(parts[-1])[0]
# return tumor, case
# # -------- build the feature path --------
# def _build_feature_path(self, tumor_name: str, case_name: str):
# if self.image_feature_suffix == ".pt":
# subdir = "pt_files"
# elif self.image_feature_suffix == ".csv":
# subdir = "csv_files"
# elif self.image_feature_suffix == ".h5":
# subdir = "h5_files"
# else:
# raise ValueError(f"Unknown feature suffix: {self.image_feature_suffix}")
# return os.path.join(
# self.image_feature_prefix,
# f"{tumor_name}{self.identifier}",
# subdir,
# case_name + self.image_feature_suffix
# )
# # -------- choose patch indices (supports linspace / random / random_full) --------
# def _choose_indices(self, total_rows: int, rng: np.random.Generator):
# k = self.sample_num
# if total_rows <= 0:
# return np.array([], dtype=int)
# strat = self.sample_strategy
# if strat == "random_full":
# # Always return exactly k indices; sample with replacement if there are fewer rows
# replace = total_rows < k
# idx = rng.choice(total_rows, size=k, replace=replace)
# return np.sort(idx.astype(int))
# if strat == "random":
# # Random without replacement; if there are fewer than k rows use them all (returns < k)
# if total_rows <= k:
# return np.arange(total_rows, dtype=int)
# idx = rng.choice(total_rows, size=k, replace=False)
# return np.sort(idx.astype(int))
# # Default: evenly spaced + jitter; if there are fewer than k rows take all of them
# if total_rows <= k:
# return np.arange(total_rows, dtype=int)
# step = total_rows / k
# jitter = int(rng.integers(0, max(1, int(step))))
# indices = (np.floor(np.arange(k) * step + jitter)).astype(int)
# return np.clip(indices, 0, total_rows - 1)
# # ---------------------- rest of class ----------------------
# @property
# def modality_length(self):
# length_list = []
# for data_dict in self.text_data:
# cur_len = len(data_dict['input_ids'])
# image = data_dict.get('image', None)
# if image is None:
# cur_len = -cur_len
# else:
# n_images = 1 if isinstance(image, str) else len(image)
# cur_len = cur_len - n_images + self.per_image_length * n_images
# length_list.append(cur_len)
# return length_list
# def __len__(self):
# return len(self.text_data)
# def __getitem__(self, index):
# wi = get_worker_info()
# if not hasattr(self, "_printed_once"):
# print_log(
# f"[LLaVADataset] worker={wi.id if wi else -1} "
# f"effective_k={self.sample_num} strategy={self.sample_strategy}",
# logger="current"
# )
# self._printed_once = True
# data_dict = self.text_data[index]
# if data_dict.get('image', None) is None:
# return data_dict
# image_list = data_dict['image']
# if isinstance(image_list, str):
# image_list = [image_list]
# images, coords_list = [], []
# rng = self._rng()
# for image_file in image_list:
# tumor_name, case_name = self._parse_stub(image_file)
# train_image_file = self._build_feature_path(tumor_name, case_name)
# if train_image_file.endswith('.csv'):
# if not os.path.exists(train_image_file):
# raise FileNotFoundError(train_image_file)
# feats_df = pd.read_csv(train_image_file, usecols=range(512), dtype=np.float32)
# total_rows = len(feats_df)
# idx = self._choose_indices(total_rows, rng)
# feats = torch.from_numpy(feats_df.to_numpy()[idx]).float()
# images.append(feats)
# coords_list.append(None)
# elif train_image_file.endswith('.pt'):
# if not os.path.exists(train_image_file):
# raise FileNotFoundError(train_image_file)
# feats_np = torch.load(train_image_file, map_location='cpu')
# if isinstance(feats_np, torch.Tensor):
# feats_np = feats_np.cpu().numpy()
# feats_np = feats_np.astype(np.float32, copy=False)
# total_rows = feats_np.shape[0]
# idx = self._choose_indices(total_rows, rng)
# feats = torch.from_numpy(feats_np[idx]).float()
# images.append(feats)
# coords_list.append(None)
# elif train_image_file.endswith('.h5'):
# if not os.path.exists(train_image_file):
# raise FileNotFoundError(train_image_file)
# with h5py.File(train_image_file, 'r') as f:
# feats_np = f['features'][:]
# coords_np = f['coords'][:]
# if feats_np.shape[0] != coords_np.shape[0]:
# raise ValueError(
# f"Mismatch rows in features ({feats_np.shape[0]}) vs coords ({coords_np.shape[0]}) "
# f"for {train_image_file}")
# feats_np = feats_np.astype(np.float32, copy=False)
# total_rows = feats_np.shape[0]
# idx = self._choose_indices(total_rows, rng)
# feats = torch.from_numpy(feats_np[idx]).float()
# coords = torch.from_numpy(coords_np[idx]).long()
# images.append(feats)
# coords_list.append(coords)
# else:
# raise ValueError(f'Unsupported file: {train_image_file}')
# data_dict['pixel_values'] = images
# if any(c is not None for c in coords_list):
# coords_list = [c if c is not None else torch.empty(0, 2, dtype=torch.long)
# for c in coords_list]
# data_dict['coords'] = coords_list
# return data_dict
from __future__ import annotations
import json, logging, os
import numpy as np
import pandas as pd
import torch, h5py
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from torch.utils.data import Dataset, get_worker_info
from xtuner.registry import BUILDER
from .huggingface import process_hf_dataset
class LLaVADataset(Dataset):
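    """LLaVA-style dataset over pre-extracted whole-slide-image (WSI) patch features.

    Text conversations are either loaded from an offline-processed folder or built
    on the fly from a .json/.jsonl file via `process_hf_dataset`. For every image
    referenced by a sample, per-patch features are read from a .pt, .csv or .h5
    file and subsampled to at most `sample_num` patches (exactly `sample_num` for
    the 'random_full' strategy) using the configured `sample_strategy`.
    """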
def __init__(self,
image_folder,
image_path_list,
per_image_length,
data_path=None,
tokenizer=None,
offline_processed_text_folder=None,
max_dataset_length=None,
dataset_map_fn=None,
template_map_fn=None,
max_length=2048,
pad_image_to_square=False,
sample_num=10240,
image_feature_prefix='',
identifier='',
image_feature_suffix='.pt',
unwanted_prefix_csv=None,
sample_strategy='linspace',
# ---------- DEBUG ----------
debug_max_samples=None,
debug_ratio=None,
debug_shuffle=True,
debug_seed=3407,
debug_include_ids=None):
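        """Key arguments (see also the class docstring):

        - sample_num: number of patch features kept per slide.
        - sample_strategy: 'linspace' (evenly spaced, default), 'random'
          (without replacement) or 'random_full' (always exactly `sample_num`
          indices, with replacement when a slide has fewer patches).
        - image_feature_suffix: one of '.pt', '.csv', '.h5'.
        - unwanted_prefix_csv: optional CSV whose first column lists slide-ID
          prefixes to drop; a hardcoded TCGA list is used as fallback.
        - debug_*: optional subsetting controls for quick debug runs.
        """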
super().__init__()
self.sample_num = int(sample_num)
self.per_image_length = int(per_image_length)
self.pad_image_to_square = pad_image_to_square
self.image_feature_prefix = image_feature_prefix
self.identifier = identifier
self.sample_strategy = sample_strategy # 'linspace' | 'random' | 'random_full'
# debug opts
self._dbg_max = debug_max_samples
self._dbg_ratio = debug_ratio
self._dbg_shuffle = debug_shuffle
self._dbg_seed = int(debug_seed)
self._dbg_include_ids = set(debug_include_ids) if debug_include_ids else None
        assert offline_processed_text_folder or (data_path and tokenizer), \
            'Provide either `offline_processed_text_folder` or both `data_path` and `tokenizer`.'
if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and `data_path` are set; '
                'the dataset will be loaded from `offline_processed_text_folder` '
                f'({offline_processed_text_folder}).',
                logger='current', level=logging.WARNING)
# ---------------------- load text ----------------------
if offline_processed_text_folder is not None:
ds = load_from_disk(offline_processed_text_folder)
if isinstance(ds, DatasetDict):
ds = ds.get('train', None) or next(iter(ds.values()))
assert isinstance(ds, HFDataset)
text_ds = ds
text_ds = self._apply_debug_subset_to_hf(text_ds)
self.text_data = text_ds
else:
if data_path.endswith('.json'):
                with open(data_path) as f:
                    json_data = json.load(f)
elif data_path.endswith('.jsonl'):
json_data = self._load_jsonl(data_path)
else:
                raise NotImplementedError(f'Unsupported data_path format: {data_path}')
# ---- unwanted prefixes
unwanted_prefixes = self._load_unwanted_prefixes(unwanted_prefix_csv)
original_count = len(json_data)
filtered = []
for item in json_data:
imgs = item.get('image', [])
if isinstance(imgs, str):
imgs = [imgs]
keep = True
for img in imgs:
if any(pref in img for pref in unwanted_prefixes):
keep = False
break
if keep:
filtered.append(item)
json_data = filtered
print_log(f'Filtered out {original_count - len(json_data)} samples.', logger='current')
# ---- debug include_ids
if self._dbg_include_ids:
keep = [it for it in json_data if str(it.get('id')) in self._dbg_include_ids]
print_log(f'[DEBUG] include_ids -> keep {len(keep)}/{len(json_data)}', logger='current')
json_data = keep
# ---- debug subset
json_data = self._apply_debug_subset_to_list(json_data)
# id -> str
for idx in range(len(json_data)):
if isinstance(json_data[idx].get('id'), int):
json_data[idx]['id'] = str(json_data[idx]['id'])
# HF map & template
json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
self.text_data = process_hf_dataset(
dataset=json_data,
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=dataset_map_fn,
template_map_fn=template_map_fn,
split='train',
max_dataset_length=max_dataset_length,
remove_unused_columns=False,
pack_to_max_length=False,
with_image_token=True,
per_image_length=self.per_image_length)
# ---------------------- image feature suffix sanity ----------------------
if image_feature_suffix not in ['.csv', '.pt', '.h5']:
raise ValueError(
f'Unsupported image feature suffix: {image_feature_suffix}. '
'Supported suffixes are: .csv, .pt, .h5')
self.image_feature_suffix = image_feature_suffix
self.image_folder = image_folder
self.image_path_list = image_path_list
# ---------------------- helpers ----------------------
def _load_unwanted_prefixes(self, csv_path):
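        """Return the set of image-path prefixes to filter out.

        If `csv_path` exists, prefixes are taken from the first column of the
        CSV (pandas treats the first row as a header); otherwise, or if the
        file cannot be read, a hardcoded TCGA slide-ID list is used.
        """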
unwanted_prefixes = set()
if csv_path and os.path.exists(csv_path):
print_log(f'Loading unwanted prefixes from: {csv_path}', logger='current')
try:
df = pd.read_csv(csv_path)
unwanted_prefixes = set(df.iloc[:, 0].astype(str).tolist())
print_log(f'Loaded {len(unwanted_prefixes)} prefixes to filter out.', logger='current')
except Exception as e:
print_log(f'Could not read CSV file {csv_path}. Error: {e}',
logger='current', level=logging.ERROR)
print_log('Falling back to hardcoded list.', logger='current', level=logging.WARNING)
if not unwanted_prefixes:
print_log('Using hardcoded unwanted prefix list.', logger='current', level=logging.WARNING)
unwanted_prefixes = {
"TCGA-HT-7476-01Z-00-DX2", "TCGA-44-7661-01Z-00-DX1", "TCGA-DB-A64V-01Z-00-DX1",
"TCGA-CS-4938-01Z-00-DX1", "TCGA-DB-5273-01Z-00-DX2", "TCGA-DB-5278-01Z-00-DX1",
"TCGA-DB-A4XA-01Z-00-DX1", "TCGA-DB-A4XB-01Z-00-DX1", "TCGA-DB-A4XC-01Z-00-DX2",
"TCGA-DU-5849-01Z-00-DX1", "TCGA-DU-6399-01Z-00-DX1", "TCGA-DU-7006-01Z-00-DX1",
"TCGA-DU-7013-01Z-00-DX1", "TCGA-DU-8165-01Z-00-DX1", "TCGA-DU-A76O-01Z-00-DX1",
"TCGA-DU-A7TG-01Z-00-DX1", "TCGA-E1-A7YM-01Z-00-DX1", "TCGA-E1-A7Z6-01Z-00-DX1",
"TCGA-FG-A6J3-01Z-00-DX2", "TCGA-HT-7467-01Z-00-DX2", "TCGA-HT-7468-01Z-00-DX6",
"TCGA-HT-7470-01Z-00-DX4", "TCGA-HT-7470-01Z-00-DX9", "TCGA-HT-7473-01Z-00-DX2",
"TCGA-HT-7475-01Z-00-DX5", "TCGA-HT-7481-01Z-00-DX1", "TCGA-HT-7482-01Z-00-DX6",
"TCGA-HT-7601-01Z-00-DX3", "TCGA-HT-7607-01Z-00-DX10", "TCGA-HT-7608-01Z-00-DX2",
"TCGA-HT-7616-01Z-00-DX1", "TCGA-HT-7684-01Z-00-DX2", "TCGA-HT-7689-01Z-00-DX1",
"TCGA-HT-7690-01Z-00-DX4", "TCGA-HT-7855-01Z-00-DX1", "TCGA-HT-7856-01Z-00-DX6",
"TCGA-HT-7874-01Z-00-DX2", "TCGA-HT-8105-01Z-00-DX1", "TCGA-HT-8108-01Z-00-DX1",
"TCGA-HT-A74O-01Z-00-DX1", "TCGA-IK-8125-01Z-00-DX1", "TCGA-P5-A72X-01Z-00-DX1",
"TCGA-QH-A65R-01Z-00-DX1", "TCGA-QH-A870-01Z-00-DX1", "TCGA-R8-A6MO-01Z-00-DX7",
"TCGA-S9-A6TX-01Z-00-DX1", "TCGA-TM-A84I-01Z-00-DX1", "TCGA-TM-A84L-01Z-00-DX1",
"TCGA-TM-A84O-01Z-00-DX1", "TCGA-TQ-A7RP-01Z-00-DX1", "TCGA-VM-A8C8-01Z-00-DX8",
"TCGA-VM-A8C9-01Z-00-DX9", "TCGA-VM-A8CA-01Z-00-DX4", "TCGA-VM-A8CB-01Z-00-DX4",
"TCGA-VM-A8CB-01Z-00-DX5", "TCGA-VM-A8CD-01Z-00-DX6", "TCGA-VM-A8CE-01Z-00-DX1",
"TCGA-VM-A8CE-01Z-00-DX7", "TCGA-QK-A8ZB-01Z-00-DX1"
}
return unwanted_prefixes
def _load_jsonl(self, json_file):
with open(json_file) as f:
return [json.loads(line) for line in f]
def _apply_debug_subset_to_list(self, items):
if not items:
return items
n_before = len(items)
if self._dbg_include_ids:
items = [it for it in items if str(it.get('id')) in self._dbg_include_ids]
n_before = len(items)
print_log(f'[DEBUG] include_ids -> keep {n_before}', logger='current')
if self._dbg_max is None and self._dbg_ratio is not None:
self._dbg_max = max(1, int(round(n_before * float(self._dbg_ratio))))
if self._dbg_max is None:
print_log('[DEBUG] dataset full size used.', logger='current')
return items
k = min(int(self._dbg_max), n_before)
if k <= 0:
return items
if self._dbg_shuffle:
rng = np.random.default_rng(self._dbg_seed)
idx = rng.choice(n_before, size=k, replace=False)
idx = sorted(idx.tolist())
items = [items[i] for i in idx]
else:
items = items[:k]
print_log(f'[DEBUG] subset: {len(items)}/{n_before} samples used '
f'({"random" if self._dbg_shuffle else "head"}).',
logger='current')
return items
def _apply_debug_subset_to_hf(self, ds: HFDataset) -> HFDataset:
n_before = ds.num_rows
if self._dbg_include_ids:
keep_idx = [i for i, ex in enumerate(ds) if str(ex.get('id')) in self._dbg_include_ids]
ds = ds.select(keep_idx)
print_log(f'[DEBUG] include_ids -> keep {ds.num_rows}/{n_before}', logger='current')
n_before = ds.num_rows
if self._dbg_max is None and self._dbg_ratio is not None:
self._dbg_max = max(1, int(round(n_before * float(self._dbg_ratio))))
if self._dbg_max is None:
print_log('[DEBUG] dataset full size used (offline).', logger='current')
return ds
k = min(int(self._dbg_max), n_before)
if k <= 0:
return ds
if self._dbg_shuffle:
rng = np.random.default_rng(self._dbg_seed)
idx = rng.choice(n_before, size=k, replace=False)
idx = sorted(idx.tolist())
else:
idx = list(range(k))
ds = ds.select(idx)
print_log(f'[DEBUG] subset (offline): {ds.num_rows}/{n_before} samples used '
f'({"random" if self._dbg_shuffle else "head"}).',
logger='current')
return ds
# -------- per-worker RNG --------
def _rng(self):
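        """Return a numpy Generator seeded per worker for reproducibility."""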
wi = get_worker_info()
base = self._dbg_seed
if wi is None:
seed = (base ^ (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
else:
seed = (base + wi.id + (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
return np.random.default_rng(seed)
# -------- path parsing --------
def _parse_stub(self, image_path: str):
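        """Split an image path into (tumor_name, case_name).

        The second-to-last path component (lowercased) is the tumor name and the
        file stem is the case name, e.g. a hypothetical 'lgg/TCGA-XX-XXXX-01Z-00-DX1.svs'
        yields ('lgg', 'TCGA-XX-XXXX-01Z-00-DX1'); single-component paths fall back
        to the part of the stem before the first '-'.
        """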
norm = os.path.normpath(image_path)
parts = norm.split(os.sep)
if len(parts) < 2:
fname = os.path.splitext(parts[-1])[0]
tumor = fname.split('-')[0].lower() if '-' in fname else 'unknown'
case = fname
else:
tumor = parts[-2].lower()
case = os.path.splitext(parts[-1])[0]
return tumor, case
def _build_feature_path(self, tumor_name: str, case_name: str):
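        """Join prefix, tumor folder, per-suffix subdir and case name.

        E.g. ('lgg', 'TCGA-XX-XXXX-01Z-00-DX1') with suffix '.pt' resolves to
        '<image_feature_prefix>/lgg<identifier>/pt_files/TCGA-XX-XXXX-01Z-00-DX1.pt'
        (illustrative; the actual prefix and identifier come from the config).
        """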
if self.image_feature_suffix == ".pt":
subdir = "pt_files"
elif self.image_feature_suffix == ".csv":
subdir = "csv_files"
elif self.image_feature_suffix == ".h5":
subdir = "h5_files"
else:
raise ValueError(f"Unknown feature suffix: {self.image_feature_suffix}")
return os.path.join(
self.image_feature_prefix,
f"{tumor_name}{self.identifier}",
subdir,
case_name + self.image_feature_suffix
)
# -------- choose patch indices --------
def _choose_indices(self, total_rows: int, rng: np.random.Generator):
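        """Pick up to `sample_num` (k) patch indices out of `total_rows`, sorted.

        - 'random_full': always exactly k indices, with replacement if the
          slide has fewer than k patches.
        - 'random': without replacement; returns all rows when fewer than k.
        - 'linspace' (default): evenly spaced indices with a random jitter;
          returns all rows when fewer than k.
        """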
k = self.sample_num
if total_rows <= 0:
return np.array([], dtype=int)
if self.sample_strategy == "random_full":
# Always exactly k rows; with replacement if needed
replace = total_rows < k
idx = rng.choice(total_rows, size=k, replace=replace)
return np.sort(idx.astype(int))
if self.sample_strategy == "random":
if total_rows <= k:
return np.arange(total_rows, dtype=int)
idx = rng.choice(total_rows, size=k, replace=False)
return np.sort(idx.astype(int))
# linspace
if total_rows <= k:
return np.arange(total_rows, dtype=int)
step = total_rows / k
jitter = int(rng.integers(0, max(1, int(step))))
indices = (np.floor(np.arange(k) * step + jitter)).astype(int)
return np.clip(indices, 0, total_rows - 1)
# ---------------------- rest of class ----------------------
@property
def modality_length(self):
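        """Per-sample token lengths (e.g. for length-grouped batch sampling).

        Text-only samples are flagged with a negative length; for samples with
        images, each image placeholder token is counted as `per_image_length`
        tokens instead.
        """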
length_list = []
for data_dict in self.text_data:
cur_len = len(data_dict['input_ids'])
image = data_dict.get('image', None)
if image is None:
cur_len = -cur_len
else:
n_images = 1 if isinstance(image, str) else len(image)
cur_len = cur_len - n_images + self.per_image_length * n_images
length_list.append(cur_len)
return length_list
def __len__(self):
return len(self.text_data)
def __getitem__(self, index):
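        """Return the tokenized text fields plus the loaded patch features.

        Adds 'pixel_values' (one [num_patches, feat_dim] float tensor per image)
        and, when any image comes from an .h5 file, 'coords' (patch coordinates
        as long tensors, with empty placeholders for images without coords).
        """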
data_dict = self.text_data[index]
if data_dict.get('image', None) is None:
return data_dict
image_list = data_dict['image']
if isinstance(image_list, str):
image_list = [image_list]
images, coords_list = [], []
rng = self._rng()
for image_file in image_list:
tumor_name, case_name = self._parse_stub(image_file)
train_image_file = self._build_feature_path(tumor_name, case_name)
if train_image_file.endswith('.csv'):
if not os.path.exists(train_image_file):
raise FileNotFoundError(train_image_file)
feats_df = pd.read_csv(train_image_file, usecols=range(512), dtype=np.float32)
total_rows = len(feats_df)
idx = self._choose_indices(total_rows, rng)
feats = torch.from_numpy(feats_df.to_numpy()[idx]).float()
images.append(feats)
coords_list.append(None)
elif train_image_file.endswith('.pt'):
if not os.path.exists(train_image_file):
raise FileNotFoundError(train_image_file)
feats_np = torch.load(train_image_file, map_location='cpu')
if isinstance(feats_np, torch.Tensor):
feats_np = feats_np.cpu().numpy()
feats_np = feats_np.astype(np.float32, copy=False)
total_rows = feats_np.shape[0]
idx = self._choose_indices(total_rows, rng)
feats = torch.from_numpy(feats_np[idx]).float()
images.append(feats)
coords_list.append(None)
elif train_image_file.endswith('.h5'):
if not os.path.exists(train_image_file):
raise FileNotFoundError(train_image_file)
with h5py.File(train_image_file, 'r') as f:
feats_np = f['features'][:]
coords_np = f['coords'][:]
if feats_np.shape[0] != coords_np.shape[0]:
raise ValueError(
f"Mismatch rows in features ({feats_np.shape[0]}) vs coords ({coords_np.shape[0]}) "
f"for {train_image_file}")
feats_np = feats_np.astype(np.float32, copy=False)
total_rows = feats_np.shape[0]
idx = self._choose_indices(total_rows, rng)
feats = torch.from_numpy(feats_np[idx]).float()
coords = torch.from_numpy(coords_np[idx]).long()
images.append(feats)
coords_list.append(coords)
else:
raise ValueError(f'Unsupported file: {train_image_file}')
data_dict['pixel_values'] = images
if any(c is not None for c in coords_list):
coords_list = [c if c is not None else torch.empty(0, 2, dtype=torch.long)
for c in coords_list]
data_dict['coords'] = coords_list
return data_dict
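

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only). It exercises `_choose_indices`
# without touching any data on disk; `LLaVADataset.__new__` is used to skip the
# text/feature loading a real run would need, and the numbers below are
# arbitrary example values. Guarded so it never runs on import.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    _ds = LLaVADataset.__new__(LLaVADataset)  # bypass __init__ for this check
    _ds.sample_num = 8
    _ds.sample_strategy = 'linspace'
    _rng = np.random.default_rng(0)
    print(_ds._choose_indices(100, _rng))  # 8 roughly evenly spaced indices in [0, 99]
    _ds.sample_strategy = 'random_full'
    print(_ds._choose_indices(3, _rng))    # exactly 8 indices, drawn with replacement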