sfaira.estimators.EstimatorKerasCelltype
class sfaira.estimators.EstimatorKerasCelltype(data: typing.Union[anndata._core.anndata.AnnData, sfaira.data.store.stores.single.StoreSingleFeatureSpace], model_dir: typing.Optional[str], model_id: typing.Optional[str], model_topology: sfaira.versions.topologies.class_interface.TopologyContainer, weights_md5: typing.Optional[str] = None, cache_path: str = 'cache/', celltype_ontology: typing.Optional[sfaira.versions.metadata.base.OntologyObo] = None, max_class_weight: float = 1000.0, remove_unlabeled_cells: bool = True, adata_ids: sfaira.consts.adata_fields.AdataIds = <sfaira.consts.adata_fields.AdataIdsSfaira object>)
Estimator class for the cell type model.
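This page does not say how max_class_weight is applied. As an illustration only, suppose the estimator derives per-class loss weights from label frequencies and clips them at max_class_weight. The minimal sketch below computes balanced inverse-frequency weights, n_total / (n_classes * count), and applies such a cap. The helper capped_class_weights is hypothetical and is not part of sfaira.

```python
from collections import Counter
from typing import Dict, Hashable, List


def capped_class_weights(
    labels: List[Hashable], max_class_weight: float = 1000.0
) -> Dict[Hashable, float]:
    """Hypothetical helper (not sfaira API): inverse-frequency class weights with an upper cap."""
    counts = Counter(labels)
    n_total = len(labels)
    n_classes = len(counts)
    weights = {}
    for label, count in counts.items():
        # Balanced weighting: rare classes receive proportionally larger weights.
        w = n_total / (n_classes * count)
        # Clip so that a very rare class cannot dominate the loss.
        weights[label] = min(w, max_class_weight)
    return weights


labels = ["T cell"] * 90 + ["B cell"] * 9 + ["NK cell"]
print(capped_class_weights(labels, max_class_weight=10.0))
```

With this input, the uncapped weight for the single "NK cell" observation would be about 33.3. The cap lowers it to 10.0, while the weights for the more frequent classes are unchanged.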
Attributes
Methods
compute_gradients_input([test_data, ...])
evaluate([batch_size, max_steps, weighted]): Evaluate the custom model on local data.
evaluate_any(idx[, batch_size, max_steps, ...]): Evaluate the custom model on any local data.
get_one_time_tf_dataset(idx, mode[, ...])
init_model([clear_weight_cache, ...]): Instantiate the model.
Loads model weights from local directory or Zenodo.
predict([batch_size, max_steps]): Return the prediction of the model.
split_train_val_test(val_split, test_split): Split indices in the store into train, validation and test splits.
train(optimizer, lr[, epochs, ...]): Train model.
ytrue([batch_size, max_steps]): Return the true labels of the test set.
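To give a sense of what split_train_val_test produces, the minimal sketch below performs a random three-way split of indices. It assumes that val_split and test_split are fractions of all indices and that the split is purely random. The actual method may partition differently, for example by dataset. The helper split_indices is hypothetical and is not sfaira's implementation.

```python
import random
from typing import List, Optional, Tuple


def split_indices(
    idx: List[int],
    val_split: float,
    test_split: float,
    seed: Optional[int] = 0,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper (not sfaira API): split indices into disjoint train/val/test lists.

    val_split and test_split are assumed to be fractions of len(idx).
    """
    if val_split < 0.0 or test_split < 0.0 or val_split + test_split >= 1.0:
        raise ValueError("val_split and test_split must be non-negative and sum to less than 1")
    shuffled = list(idx)
    # A private, seeded RNG keeps the split reproducible without touching global state.
    random.Random(seed).shuffle(shuffled)
    n = len(shuffled)
    n_test = int(round(n * test_split))
    n_val = int(round(n * val_split))
    # Carve the test and validation blocks off the front; the remainder is training data.
    test = sorted(shuffled[:n_test])
    val = sorted(shuffled[n_test:n_test + n_val])
    train = sorted(shuffled[n_test + n_val:])
    return train, val, test


train, val, test = split_indices(list(range(100)), val_split=0.1, test_split=0.2)
print(len(train), len(val), len(test))
```

Because the three lists come from disjoint slices of a single shuffled copy, each index ends up in exactly one split.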