import os
import re
import math
import json
import argparse
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np

# NOTE: h5py and pyedflib are imported lazily inside the functions that need
# them, so the numpy-only helpers (graph construction, label parsing,
# windowing) can be imported and tested without the EDF/HDF5 I/O stack.

# ---------- 10-20 electrode 2D coords (x, y), relatively normalised ----------
E1020 = {
    "FP1": (-0.5, 1.0), "FP2": (0.5, 1.0),
    "F7": (-1.0, 0.6), "F3": (-0.3, 0.6), "FZ": (0.0, 0.8), "F4": (0.3, 0.6), "F8": (1.0, 0.6),
    "T3": (-1.2, 0.2), "T7": (-1.2, 0.2), "C3": (-0.4, 0.2), "CZ": (0.0, 0.3), "C4": (0.4, 0.2),
    "T4": (1.2, 0.2), "T8": (1.2, 0.2),
    "T5": (-1.1, -0.2), "P7": (-1.1, -0.2), "P3": (-0.3, -0.2), "PZ": (0.0, -0.1), "P4": (0.3, -0.2),
    "T6": (1.1, -0.2), "P8": (1.1, -0.2),
    "O1": (-0.5, -0.8), "O2": (0.5, -0.8),
}
ALIASES = {"T3": "T7", "T4": "T8", "T5": "P7", "T6": "P8"}  # unify old/new 10-20 names

# Channels dropped everywhere: ECG / VNS leads and dangling labels ending in '-'.
_BLOCKLIST = re.compile(r'(?:ECG|VNS|-$)', re.IGNORECASE)

# "File Name: xxx.edf" header line of a summary file.
_SUMMARY_FILE_RE = re.compile(r'File Name:\s*(\S+\.edf)', re.IGNORECASE)
# "Seizure Start Time: 2996 seconds" and the numbered variant
# "Seizure 1 Start Time: 2996 seconds"; group 2 is the time in seconds.
# Capturing only the number after the label avoids picking up the seizure index.
_SUMMARY_SEIZ_RE = re.compile(
    r'Seizure\s*(?:\d+\s*)?(Start|End)\s*Time\s*:?\s*(\d+(?:\.\d+)?)', re.IGNORECASE)
# Signed decimal numbers in a plain-text .seizures file.
_NUMBER_RE = re.compile(r'[-+]?\d*\.?\d+')


def _norm_e(name: str) -> Optional[str]:
    """Normalise an electrode name (upper-case, alias-resolved); None if unknown."""
    name = name.upper().strip()
    name = ALIASES.get(name, name)
    return name if name in E1020 else None


def pair_midpoint(ch_label: str) -> Optional[Tuple[float, float]]:
    """Return the 2D position of a bipolar channel 'A-B'.

    The position is the midpoint of electrodes A and B. Returns None when the
    label has no '-' or either electrode is not in E1020.
    """
    if '-' not in ch_label:
        return None
    a, b = [x.strip().upper() for x in ch_label.split('-', 1)]
    a = _norm_e(a)
    b = _norm_e(b)
    if (a is None) or (b is None):
        return None
    ax, ay = E1020[a]
    bx, by = E1020[b]
    return ((ax + bx) / 2.0, (ay + by) / 2.0)


def build_1020_graph(channels: List[str], k: int = 8, sigma: float = 0.5, radius: float = 1.6):
    """Build a static distance graph over channels placed on the 10-20 layout.

    The edge weight is
        w_ij = exp(-||v_i - v_j||^2 / sigma^2)   if ||v_i - v_j|| <= radius, else 0.
    Each node then keeps its k highest-weight neighbours; edges are undirected.

    Blocklisted and unmappable channels are dropped.

    Returns:
        keep: the retained channel labels, in input order.
        edge_index: int64 array of shape (2, 2E), holding both directions.
        edge_weight: float32 array of shape (2E,).
        When no edge survives, edge_index has shape (2, 0).

    Raises:
        RuntimeError: fewer than 2 channels map onto 10-20.
    """
    coords = []
    keep = []
    for ch in channels:
        if _BLOCKLIST.search(ch):
            continue
        m = pair_midpoint(ch)
        if m is not None:
            coords.append(m)
            keep.append(ch)
    C = len(keep)
    if C < 2:
        raise RuntimeError("Không đủ kênh ánh xạ được sang 10-20")

    xy = np.array(coords, dtype=np.float32)
    D = np.sqrt(((xy[None, :, :] - xy[:, None, :]) ** 2).sum(axis=-1))  # (C, C)
    W = np.exp(-(D ** 2) / (sigma ** 2))
    W[D > radius] = 0.0
    np.fill_diagonal(W, 0.0)

    n_nb = max(1, min(k, C - 1))
    edges = set()
    for i in range(C):
        for j in np.argsort(-W[i])[:n_nb]:
            j = int(j)
            a, b = (i, j) if i < j else (j, i)
            # argsort may pick zero-weight (out-of-radius / self) slots; skip them.
            if W[a, b] > 0:
                edges.add((a, b))
    edges = sorted(edges)

    if not edges:
        return keep, np.empty((2, 0), dtype=np.int64), np.empty((0,), dtype=np.float32)
    ei = np.array(edges, dtype=np.int64).T
    ei = np.hstack([ei, ei[::-1, :]])  # add reverse direction
    ew = np.array([W[a, b] for (a, b) in edges], dtype=np.float32)
    ew = np.concatenate([ew, ew])
    return keep, ei, ew


# ---------- EDF utils ----------
@contextmanager
def _edf_reader(edf_path: Path):
    """Open an EDF with pyedflib and always release the handle on exit."""
    import pyedflib
    f = pyedflib.EdfReader(str(edf_path))
    try:
        yield f
    finally:
        f._close()


def list_edf_files(patient_dir: Path):
    """Return the sorted list of *.edf files directly inside patient_dir."""
    return sorted(p for p in patient_dir.glob("*.edf") if p.is_file())


def edf_channel_labels(edf_path: Path):
    """Return the stripped channel labels of an EDF, minus blocklisted ones."""
    with _edf_reader(edf_path) as f:
        labels = [f.getLabel(i).strip() for i in range(f.signals_in_file)]
    return [ch for ch in labels if not _BLOCKLIST.search(ch)]


def intersection_channels(edf_paths):
    """Return the sorted channel labels common to every EDF that has any channels."""
    common: Optional[set] = None
    for p in edf_paths:
        chans = set(edf_channel_labels(p))
        if not chans:
            continue
        common = chans if common is None else (common & chans)
    return sorted(common) if common else []


def read_edf_signals(edf_path: Path, keep_channels):
    """Read the requested channels of an EDF.

    Blocklisted names in keep_channels are skipped. The sampling rate is taken
    from the kept channels only; dropped leads such as ECG may run at another
    rate.

    Returns:
        sigs: float32 array of shape (C, N).
        fs: int sampling rate.
        out_labels: the channel labels that were read, in order.

    Raises:
        RuntimeError: a requested channel is missing, no channel remains, or
            the kept channels have different sampling rates.
    """
    with _edf_reader(edf_path) as f:
        labels = [f.getLabel(i).strip() for i in range(f.signals_in_file)]
        idxs: List[int] = []
        out_labels: List[str] = []
        for ch in keep_channels:
            if _BLOCKLIST.search(ch):
                continue
            try:
                i = labels.index(ch)
            except ValueError:
                raise RuntimeError(f"Channel {ch} not found in {edf_path.name}") from None
            idxs.append(i)
            out_labels.append(ch)
        if not idxs:
            raise RuntimeError(f"No usable channels in {edf_path.name}")
        rates = {int(round(f.getSampleFrequency(i))) for i in idxs}
        if len(rates) != 1:
            raise RuntimeError(f"Mixed sampling rates {sorted(rates)} in {edf_path.name}")
        fs = rates.pop()
        sigs = np.vstack([f.readSignal(i) for i in idxs]).astype(np.float32)  # (C, N)
    return sigs, fs, out_labels


# ---------- seizure labels ----------
def parse_summary(summary_path: Path) -> Dict[str, List[Tuple[float, float]]]:
    """Parse a summary text file into {edf_name: [(start_s, end_s), ...]}.

    Handles both "Seizure Start Time: X seconds" and the numbered form
    "Seizure N Start Time: X seconds". Times listed under a "File Name:" line
    are paired consecutively as (start, end).

    Returns {} when summary_path is falsy or does not exist.
    """
    mapping: Dict[str, List[Tuple[float, float]]] = {}
    if not (summary_path and summary_path.exists()):
        return mapping

    def _flush(name: Optional[str], times: List[float]) -> None:
        if name and times:
            pairs = [(times[i], times[i + 1]) for i in range(0, len(times) - 1, 2)]
            mapping.setdefault(name, []).extend(pairs)

    curr: Optional[str] = None
    buf: List[float] = []
    with summary_path.open("r", errors="ignore") as f:
        for line in f:
            line = line.strip()
            mfile = _SUMMARY_FILE_RE.search(line)
            if mfile:
                _flush(curr, buf)
                curr = mfile.group(1)
                buf = []
                continue
            mseiz = _SUMMARY_SEIZ_RE.search(line)
            if mseiz:
                buf.append(float(mseiz.group(2)))
    _flush(curr, buf)
    return mapping


def parse_seizures_file(seiz_file: Path) -> List[Tuple[float, float]]:
    """Parse a plain-text seizure file of "start end" lines (seconds).

    NOTE: the .seizures files shipped with PhysioNet CHB-MIT are binary WFDB
    annotations, not text. Such files (detected by NUL bytes) return [], so the
    caller falls back to the summary file instead of reading digit bytes as
    times.

    Lines whose end is not after their start are dropped. Returns [] when the
    file does not exist.
    """
    if not seiz_file.exists():
        return []
    raw = seiz_file.read_bytes()
    if b"\x00" in raw:
        return []
    intervals: List[Tuple[float, float]] = []
    for line in raw.decode("utf-8", errors="ignore").splitlines():
        nums = [float(x) for x in _NUMBER_RE.findall(line)]
        if len(nums) >= 2 and nums[1] > nums[0]:
            intervals.append((nums[0], nums[1]))
    return intervals


def slice_starts(N, fs, clip_sec, hop_sec):
    """Return sample offsets of full windows of clip_sec seconds, hop_sec apart.

    Only windows lying entirely inside a signal of N samples are returned.
    The hop is clamped to at least one sample.
    """
    T = int(fs * clip_sec)
    hop = max(1, int(fs * hop_sec))
    if T <= 0:
        raise ValueError("clip_sec * fs must cover at least one sample")
    if N < T:
        return []
    return list(range(0, N - T + 1, hop))


def label_for_window(start, T, fs, intervals, min_overlap_sec):
    """Return 1 if window [start, start+T) overlaps a seizure interval, else 0.

    The overlap must be positive and at least min_overlap_sec seconds. Intervals
    are (start_s, end_s) in seconds.
    """
    t0 = start / float(fs)
    t1 = (start + T) / float(fs)
    for a, b in intervals:
        ov = min(t1, b) - max(t0, a)
        # ov > 0 keeps min_overlap_sec=0 from labelling disjoint windows positive.
        if ov > 0 and ov >= min_overlap_sec:
            return 1
    return 0


def zscore_perclip(x, eps: float = 1e-6):
    """Z-score each channel (row) of a (C, T) clip along time; returns float32.

    Constant channels map to 0 thanks to eps in the denominator.
    """
    x = np.asarray(x, dtype=np.float32)
    mu = x.mean(axis=-1, keepdims=True)
    sd = x.std(axis=-1, keepdims=True)
    return ((x - mu) / (sd + eps)).astype(np.float32)


def _resample_linear(sigs, fs_in: int, fs_out: int):
    """Linearly resample a (C, N) array from fs_in to fs_out Hz; returns float32.

    Output sample n is taken at input position n * fs_in / fs_out, so times in
    seconds stay aligned with the seizure annotations. Float64 time axes avoid
    precision loss on long recordings.

    NOTE(review): there is no anti-alias filter. When downsampling
    (fs_out < fs_in), consider a polyphase low-pass resampler.
    """
    if fs_in == fs_out:
        return sigs
    n_in = sigs.shape[1]
    n_out = int(round(n_in * fs_out / fs_in))
    if n_in == 0 or n_out == 0:
        return np.zeros((sigs.shape[0], n_out), dtype=np.float32)
    t_in = np.arange(n_in, dtype=np.float64)
    t_out = np.arange(n_out, dtype=np.float64) * (fs_in / fs_out)
    return np.stack([np.interp(t_out, t_in, ch) for ch in sigs], axis=0).astype(np.float32)


class H5Appender:
    """Incrementally append (clip, label, file_id) rows to an HDF5 file.

    Datasets written:
      - clips (M, C, T, 1) float32
      - labels (M,) int64
      - file_ids (M,) int32
      - static channels / edge_index / edge_weight
    Attrs written: fs, T, patient.

    Usable as a context manager; close() is safe to call more than once.
    """

    def __init__(self, out_path: Path, C: int, T: int, fs: int, channels,
                 edge_index, edge_weight, gzip_level=4):
        import h5py
        self._closed = False
        self.f = h5py.File(str(out_path), "w")
        try:
            comp = dict(compression="gzip", compression_opts=gzip_level)
            self.ds_clips = self.f.create_dataset(
                "clips", shape=(0, C, T, 1), maxshape=(None, C, T, 1),
                dtype="float32", chunks=(16, C, T, 1), **comp)
            self.ds_labels = self.f.create_dataset(
                "labels", shape=(0,), maxshape=(None,), dtype="i8", chunks=True, **comp)
            self.ds_fileids = self.f.create_dataset(
                "file_ids", shape=(0,), maxshape=(None,), dtype="i4", chunks=True, **comp)
            self.f.attrs["fs"] = fs
            self.f.attrs["T"] = T
            self.f.attrs["patient"] = out_path.stem
            self.f.create_dataset("channels", data=np.array([c.encode() for c in channels]))
            self.f.create_dataset("edge_index", data=edge_index.astype(np.int64))
            self.f.create_dataset("edge_weight", data=edge_weight.astype(np.float32))
        except Exception:
            # Do not leak an open HDF5 handle if setup fails half-way.
            self.close()
            raise
        self.n = 0

    def append(self, clips_CT, labels, file_id: int):
        """Append M clips of shape (M, C, T) with their M labels; no-op if empty."""
        if clips_CT.size == 0:
            return
        M, C, T = clips_CT.shape
        labels = np.asarray(labels)
        if len(labels) != M:
            raise ValueError(f"got {len(labels)} labels for {M} clips")
        n0, n1 = self.n, self.n + M
        for ds in (self.ds_clips, self.ds_labels, self.ds_fileids):
            ds.resize(n1, axis=0)
        self.ds_clips[n0:n1] = clips_CT[..., None].astype(np.float32)
        self.ds_labels[n0:n1] = labels.astype(np.int64)
        self.ds_fileids[n0:n1] = np.full((M,), file_id, dtype=np.int32)
        self.n = n1

    def close(self):
        """Close the underlying HDF5 file (idempotent)."""
        if not self._closed:
            self._closed = True
            self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False


def process_patient_1020(root: Path, patient: str, out_path: Path,
                         clip_sec: float = 4.0, hop_sec: float = 2.0, fs_target: int = 256,
                         min_overlap_sec: float = 0.25, graph_k: int = 8,
                         sigma: float = 0.5, radius: float = 1.6):
    """Convert one patient's EDFs into a single HDF5 file of labelled clips.

    Steps:
      1. Keep the channels shared by all EDFs that map onto 10-20.
      2. Build the static 10-20 graph over them.
      3. Resample each recording to fs_target.
      4. Cut windows of clip_sec seconds every hop_sec seconds.
      5. Z-score each clip per channel.
      6. Label a clip 1 when it overlaps a seizure by at least min_overlap_sec.

    Seizure intervals come from a text .seizures file when one is readable,
    otherwise from the patient's *summary*.txt.

    Raises:
        FileNotFoundError: the patient directory or its EDFs are missing.
        RuntimeError: fewer than 2 channels map onto 10-20.
    """
    pat_dir = root / patient
    if not pat_dir.exists():
        raise FileNotFoundError(f"Not found: {pat_dir}")
    edfs = list_edf_files(pat_dir)
    if not edfs:
        raise FileNotFoundError(f"No EDF in {pat_dir}")

    # Channels shared by every recording of this patient ...
    keep_channels = intersection_channels(edfs)
    # ... restricted to those that map onto 10-20 (unknown channels dropped).
    keep_channels = [ch for ch in keep_channels if pair_midpoint(ch) is not None]
    if len(keep_channels) < 8:
        warnings.warn(f"[{patient}] chỉ còn {len(keep_channels)} kênh sau khi map 10-20")

    # Static 10-20 graph.
    keep_channels, ei, ew = build_1020_graph(keep_channels, k=graph_k, sigma=sigma, radius=radius)

    summ = next(iter(sorted(pat_dir.glob("*summary*.txt"))), None)
    summ_map = parse_summary(summ) if summ else {}
    seiz_map = {}
    for p in edfs:
        iv = parse_seizures_file(p.with_suffix(p.suffix + ".seizures"))
        if iv:
            seiz_map[p.name] = iv

    def intervals_for(name):
        return seiz_map.get(name, summ_map.get(name, []))

    T = int(fs_target * clip_sec)
    app = None
    total_pos = 0
    total = 0
    try:
        for file_id, edf in enumerate(edfs):
            sigs, fs, chans = read_edf_signals(edf, keep_channels)
            sigs = _resample_linear(sigs, fs, fs_target)
            C, N = sigs.shape
            starts = slice_starts(N, fs_target, clip_sec, hop_sec)
            if app is None:
                app = H5Appender(out_path, C=C, T=T, fs=fs_target, channels=chans,
                                 edge_index=ei, edge_weight=ew)
            ivals = intervals_for(edf.name)
            labels = np.array(
                [label_for_window(int(s), T, fs_target, ivals, min_overlap_sec) for s in starts],
                dtype=np.int64)
            M = len(starts)
            clips = np.empty((M, C, T), dtype=np.float32)
            for i, s in enumerate(starts):
                clips[i] = zscore_perclip(sigs[:, s:s + T])
            app.append(clips, labels, file_id=file_id)
            total += M
            total_pos += int(labels.sum())
    finally:
        # Always flush/close the HDF5 file, even if a recording fails to load.
        if app is not None:
            app.close()
    print(f"[Done] {patient} -> {out_path} | clips={total} (pos={total_pos}, neg={total-total_pos}), C={len(keep_channels)}, T={int(fs_target*clip_sec)}")


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point for process_patient_1020."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", required=True)
    ap.add_argument("--patient", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--clip-sec", type=float, default=4.0)
    ap.add_argument("--hop-sec", type=float, default=2.0)
    ap.add_argument("--fs", type=int, default=256)
    ap.add_argument("--min-overlap", type=float, default=0.25)
    ap.add_argument("--k", type=int, default=8)
    ap.add_argument("--sigma", type=float, default=0.5)
    ap.add_argument("--radius", type=float, default=1.6)
    args = ap.parse_args(argv)
    process_patient_1020(Path(args.root), args.patient, Path(args.out),
                         clip_sec=args.clip_sec, hop_sec=args.hop_sec, fs_target=args.fs,
                         min_overlap_sec=args.min_overlap, graph_k=args.k,
                         sigma=args.sigma, radius=args.radius)


if __name__ == "__main__":
    main()