Addressed issues for tokenizer; anndata tokenizer now uses a fraction of the memory
geneformer/tokenizer.py  +46 -30
CHANGED
@@ -27,6 +27,7 @@ warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
 import anndata as ad
 import loompy as lp
 import numpy as np
+import scipy.sparse as sp
 from datasets import Dataset

 logger = logging.getLogger(__name__)

@@ -35,6 +36,15 @@ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
 TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"


+def rank_genes(gene_vector, gene_tokens):
+    """
+    Rank gene expression vector.
+    """
+    # sort by median-scaled gene values
+    sorted_indices = np.argsort(-gene_vector)
+    return gene_tokens[sorted_indices]
+
+
 def tokenize_cell(gene_vector, gene_tokens):
     """
     Convert normalized gene expression vector to tokenized rank value encoding.

@@ -42,11 +52,8 @@ def tokenize_cell(gene_vector, gene_tokens):
     # create array of gene vector with token indices
     # mask undetected genes
     nonzero_mask = np.nonzero(gene_vector)[0]
-    # sort by median-scaled gene values
-    sorted_indices = np.argsort(-gene_vector[nonzero_mask])
-    # tokenize
-    sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
-    return sentence_tokens
+    # rank by median-scaled gene values
+    return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])


 class TranscriptomeTokenizer:

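To illustrate the rank value encoding above, here is a self-contained sketch of the same two functions applied to a toy cell; the expression values and token ids below are made up for illustration and are not part of the commit:

import numpy as np

def rank_genes(gene_vector, gene_tokens):
    # sort tokens by descending median-scaled expression
    sorted_indices = np.argsort(-gene_vector)
    return gene_tokens[sorted_indices]

def tokenize_cell(gene_vector, gene_tokens):
    # keep only detected genes, then rank them
    nonzero_mask = np.nonzero(gene_vector)[0]
    return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])

gene_vector = np.array([0.0, 2.5, 0.0, 7.1, 0.3])  # median-scaled expression for 5 genes
gene_tokens = np.array([10, 11, 12, 13, 14])       # token ids for the same genes
print(tokenize_cell(gene_vector, gene_tokens))     # [13 11 14]: detected genes, highest expression first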
@@ -101,6 +108,7 @@ class TranscriptomeTokenizer:
         output_directory: Path | str,
         output_prefix: str,
         file_format: Literal["loom", "h5ad"] = "loom",
+        use_generator: bool = False,
     ):
         """
         Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.

@@ -115,11 +123,13 @@
             Prefix for output .dataset
         file_format : str
             Format of input files. Can be "loom" or "h5ad".
+        use_generator : bool
+            Whether to use generator or dict for tokenization.
         """
         tokenized_cells, cell_metadata = self.tokenize_files(
             Path(data_directory), file_format
         )
-        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)
+        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata, use_generator=use_generator)

         output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
         tokenized_dataset.save_to_disk(output_path)

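For context on how the new flag is meant to be used, a minimal usage sketch, assuming the enclosing method is tokenize_data and the constructor takes a custom-attribute mapping plus nproc, as in the Geneformer package this file belongs to; the paths and attribute names are placeholders:

from geneformer import TranscriptomeTokenizer

# map input attribute names to output column names (placeholder names)
tk = TranscriptomeTokenizer({"cell_type": "cell_type"}, nproc=4)

# use_generator=True streams examples into the Hugging Face Dataset
# instead of materializing one large dict in memory
tk.tokenize_data(
    "data/h5ad_dir",   # directory of .h5ad files (placeholder path)
    "output/",         # output directory
    "my_corpus",       # output prefix -> output/my_corpus.dataset
    file_format="h5ad",
    use_generator=True,
)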
@@ -129,7 +139,7 @@
     ):
         tokenized_cells = []
         if self.custom_attr_name_dict is not None:
-            …
+            cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
             cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}

         # loops through directories to tokenize .loom files

@@ -144,7 +154,7 @@
             file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
             tokenized_cells += file_tokenized_cells
             if self.custom_attr_name_dict is not None:
-                for k in …
+                for k in cell_attr:
                     cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
             else:
                 cell_metadata = None

@@ -155,8 +165,8 @@
             raise
         return tokenized_cells, cell_metadata

-    def tokenize_anndata(self, adata_file_path):
-        adata = ad.read(adata_file_path)
+    def tokenize_anndata(self, adata_file_path, target_sum=10_000, chunk_size=512):
+        adata = ad.read(adata_file_path, backed="r")
         file_cell_metadata = {
             attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
         }

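The memory saving starts with backed="r": the expression matrix stays on disk and only the sliced rows are materialized. A standalone sketch of the difference, with a placeholder file path:

import anndata as ad

# in-memory: the full X matrix is loaded at once
adata_mem = ad.read_h5ad("example.h5ad")              # placeholder path

# backed mode: X stays on disk; slicing pulls in only the requested rows
adata_disk = ad.read_h5ad("example.h5ad", backed="r")
chunk = adata_disk[0:512].X                           # materializes just 512 cells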
@@ -176,7 +186,7 @@
         )

         try:
-            adata.obs["filter_pass"]
+            _ = adata.obs["filter_pass"]
         except KeyError:
             var_exists = False
         else:

@@ -193,24 +203,26 @@
             filter_pass_loc = np.array([i for i in range(adata.shape[0])])

         tokenized_cells = []
-        adata_filter = adata[
-            filter_pass_loc, coding_miRNA_loc  # filter cells and genes
-        ]
-        …
+        for i in range(0, len(filter_pass_loc), chunk_size):
+            idx = filter_pass_loc[i:i+chunk_size]
+            X = adata[idx].X
+
+            X_norm = (X / X[:, coding_miRNA_loc].sum(axis=1) * target_sum / norm_factor_vector)
+            X_norm = sp.csr_matrix(X_norm)
+
+            tokenized_cells += [
+                rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
+                for i in range(X_norm.shape[0])
+            ]
+
+            # add custom attributes for subview to dict
+            for k in file_cell_metadata.keys():
+                file_cell_metadata[k] += adata[idx].obs[k].tolist()

         return tokenized_cells, file_cell_metadata

-    def tokenize_file(self, loom_file_path):
+    def tokenize_file(self, loom_file_path, target_sum=10_000):
         if self.custom_attr_name_dict is not None:
             file_cell_metadata = {
                 attr_key: [] for attr_key in self.custom_attr_name_dict.keys()

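The chunked loop above is where most of the memory is saved: only chunk_size cells are ever dense at once, and ranking works directly on each CSR row's .data/.indices arrays. A self-contained sketch of the same pattern on synthetic data (shapes, densities, and normalization factors are made up and stand in for the real gene medians):

import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
X = sp.random(10_000, 2_000, density=0.05, format="csr", random_state=0)  # fake count matrix
norm_factors = rng.uniform(0.5, 2.0, size=2_000)                          # fake per-gene factors
target_sum, chunk_size = 10_000, 512

ranked = []
for start in range(0, X.shape[0], chunk_size):
    chunk = X[start:start + chunk_size].toarray()             # only one small dense block at a time
    chunk = chunk / chunk.sum(axis=1, keepdims=True) * target_sum / norm_factors
    chunk = sp.csr_matrix(chunk)
    for row in range(chunk.shape[0]):
        r = chunk[row]
        # rank the nonzero entries of this cell by descending value
        ranked.append(r.indices[np.argsort(-r.data)])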
@@ -261,7 +273,7 @@
                 subview_norm_array = (
                     subview[:, :]
                     / subview.ca.n_counts
-                    * …
+                    * target_sum
                     / norm_factor_vector[:, None]
                 )
                 # tokenize subview gene vectors

@@ -279,21 +291,25 @@

         return tokenized_cells, file_cell_metadata

-    def create_dataset(self, tokenized_cells, cell_metadata):
+    def create_dataset(self, tokenized_cells, cell_metadata, use_generator=False):
+        print("Creating dataset...")
         # create dict for dataset creation
         dataset_dict = {"input_ids": tokenized_cells}
         if self.custom_attr_name_dict is not None:
             dataset_dict.update(cell_metadata)

         # create dataset
-        …
+        if use_generator:
+            def dict_generator():
+                for i in range(len(tokenized_cells)):
+                    yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
+            output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
+        else:
+            output_dataset = Dataset.from_dict(dataset_dict)

         # truncate dataset
         def truncate(example):
-            example["input_ids"] = example["input_ids"][…
+            example["input_ids"] = example["input_ids"][:2048]
             return example

         output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
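Dataset.from_dict materializes the whole corpus as one in-memory dict before writing Arrow, while Dataset.from_generator writes examples as they are yielded; that difference is what the new use_generator flag exposes. A minimal sketch with dummy records:

from datasets import Dataset

dataset_dict = {"input_ids": [[1, 2, 3], [4, 5]], "cell_type": ["T cell", "B cell"]}

# materializes everything in memory first
ds_dict = Dataset.from_dict(dataset_dict)

# streams examples one at a time into the Arrow writer
def dict_generator():
    for i in range(len(dataset_dict["input_ids"])):
        yield {k: dataset_dict[k][i] for k in dataset_dict}

ds_gen = Dataset.from_generator(dict_generator)
print(ds_dict[0], ds_gen[0])  # both yield the same first example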
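Once saved, the tokenized corpus can be reloaded with the standard datasets API; the path below assumes the placeholder prefix from the usage sketch earlier:

from datasets import load_from_disk

corpus = load_from_disk("output/my_corpus.dataset")  # placeholder path
print(corpus)                                        # input_ids plus any custom attribute columns
print(len(corpus[0]["input_ids"]))                   # at most 2048 tokens per cell after truncation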