sfaira.estimators.EstimatorKerasCelltype

class sfaira.estimators.EstimatorKerasCelltype(data: typing.Union[anndata._core.anndata.AnnData, sfaira.data.store.stores.single.StoreSingleFeatureSpace], model_dir: typing.Optional[str], model_id: typing.Optional[str], model_topology: sfaira.versions.topologies.class_interface.TopologyContainer, weights_md5: typing.Optional[str] = None, cache_path: str = 'cache/', celltype_ontology: typing.Optional[sfaira.versions.metadata.base.OntologyObo] = None, max_class_weight: float = 1000.0, remove_unlabeled_cells: bool = True, adata_ids: sfaira.consts.adata_fields.AdataIds = <sfaira.consts.adata_fields.AdataIdsSfaira object>)

Estimator class for the cell type model.
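The signature above includes a `max_class_weight` parameter (default 1000.0), whose name suggests an upper bound on per-class weights. The sketch below shows one plausible reading: inverse-frequency class weights clipped at that bound. It is an illustrative assumption, not sfaira's implementation; the helper name `capped_class_weights` and the exact weighting formula are hypothetical.

```python
from collections import Counter


def capped_class_weights(labels, max_class_weight=1000.0):
    """Hypothetical helper: inverse-frequency weights, clipped at max_class_weight.

    Assumption: weight = n_observations / class_count. The library may use a
    different formula; only the idea of capping is illustrated here.
    """
    counts = Counter(labels)
    n_obs = len(labels)
    return {
        label: min(n_obs / count, max_class_weight)
        for label, count in counts.items()
    }


# A very rare class would receive a huge weight without the cap.
labels = ["T cell"] * 9998 + ["rare cell"] * 2
weights = capped_class_weights(labels, max_class_weight=1000.0)
```

Without a cap, the rare class in this example would be weighted 5000 times per observation, which could let a handful of cells dominate the loss.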

Attributes

model_type

ntypes

obs

obs_eval

obs_test

obs_train

ontology_ids

ontology_names

organism

using_store

celltype_universe

model

Methods

compute_gradients_input([test_data, ...])

evaluate([batch_size, max_steps, weighted])

Evaluate the custom model on local data.

evaluate_any(idx[, batch_size, max_steps, ...])

Evaluate the custom model on any local data.

get_one_time_tf_dataset(idx, mode[, ...])

init_model([clear_weight_cache, ...])

Instantiate the model.

load_pretrained_weights()

Load model weights from a local directory or Zenodo.

load_weights_from_cache()

predict([batch_size, max_steps])

Return the prediction of the model.

save_weights_to_cache()

split_train_val_test(val_split, test_split)

Split indices in the store into train, validation, and test splits.

train(optimizer, lr[, epochs, ...])

Train model.

ytrue([batch_size, max_steps])

Return the true labels of the test set.
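The index partition that `split_train_val_test(val_split, test_split)` describes can be pictured with a small standalone sketch. It assumes both arguments are fractions of all observations and that the assignment is random; the helper name `split_indices` and the `seed` argument are hypothetical, and the library's own method may interpret its arguments differently.

```python
import random


def split_indices(n_obs, val_split, test_split, seed=0):
    """Hypothetical helper: partition range(n_obs) into train/val/test indices.

    Assumption: val_split and test_split are fractions of the full data set.
    """
    if val_split < 0 or test_split < 0 or val_split + test_split >= 1:
        raise ValueError("val_split and test_split must be >= 0 and sum to < 1")
    idx = list(range(n_obs))
    random.Random(seed).shuffle(idx)  # seeded shuffle for reproducibility
    n_test = round(n_obs * test_split)
    n_val = round(n_obs * val_split)
    test = sorted(idx[:n_test])
    val = sorted(idx[n_test:n_test + n_val])
    train = sorted(idx[n_test + n_val:])
    return train, val, test


train_idx, val_idx, test_idx = split_indices(100, val_split=0.1, test_split=0.2)
```

The three index lists are disjoint and together cover every observation, which is the property that makes per-split views such as `obs_train`, `obs_eval`, and `obs_test` meaningful.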