sfaira.estimators.EstimatorKerasEmbedding
class sfaira.estimators.EstimatorKerasEmbedding(
    data: typing.Union[anndata._core.anndata.AnnData, numpy.ndarray, sfaira.data.store.stores.single.StoreSingleFeatureSpace],
    model_dir: typing.Optional[str],
    model_id: typing.Optional[str],
    model_topology: sfaira.versions.topologies.class_interface.TopologyContainer,
    weights_md5: typing.Optional[str] = None,
    cache_path: str = 'cache/',
    adata_ids: sfaira.consts.adata_fields.AdataIds = <sfaira.consts.adata_fields.AdataIdsSfaira object>
)
Estimator class for the embedding model.
Attributes
Methods
compute_gradients_input([batch_size, ...])
evaluate([batch_size, max_steps]): Evaluate the custom model on test data.
evaluate_any(idx[, batch_size, max_steps]): Evaluate the custom model on any local data.
get_one_time_tf_dataset(idx, mode[, ...])
init_model([clear_weight_cache, ...]): Instantiate the model.
Loads model weights from local directory or Zenodo.
predict([batch_size]): Return the prediction of the model.
predict_embedding([batch_size, variational]): Return the prediction in the latent space (z_mean for variational models).
split_train_val_test(val_split, test_split): Split the indices in the store into train, validation and test splits.
train(optimizer, lr[, epochs, ...]): Train the model.
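split_train_val_test partitions the indices of the store into training, validation and test sets. As a rough, library-independent sketch of that idea, assuming `val_split` and `test_split` are fractions of all observations and that the split is a seeded random shuffle (sfaira's own rules may differ), it could look like this. `split_indices` and its `seed` argument are hypothetical and not part of sfaira.

```python
import random


def split_indices(n_obs, val_split, test_split, seed=0):
    """Hypothetical helper: partition range(n_obs) into train/val/test index lists.

    Assumes val_split and test_split are fractions of all observations.
    """
    if not (val_split >= 0.0 and test_split >= 0.0 and val_split + test_split < 1.0):
        raise ValueError("val_split + test_split must lie in [0, 1)")
    idx = list(range(n_obs))
    random.Random(seed).shuffle(idx)  # reproducible shuffle of all indices
    n_test = int(round(n_obs * test_split))
    n_val = int(round(n_obs * val_split))
    # Disjoint slices of the shuffled indices; sorted for readability.
    idx_test = sorted(idx[:n_test])
    idx_val = sorted(idx[n_test:n_test + n_val])
    idx_train = sorted(idx[n_test + n_val:])
    return idx_train, idx_val, idx_test


train, val, test = split_indices(100, val_split=0.1, test_split=0.2)
print(len(train), len(val), len(test))  # 70 10 20
```

Fixing the seed makes the split reproducible across runs, which matters when comparing models trained on the same data.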
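For variational models, predict_embedding returns z_mean, i.e. the mean of the latent distribution rather than a random draw from it. The NumPy sketch below illustrates that distinction on toy encoder outputs. It assumes a diagonal Gaussian latent parameterized by a log-variance; `embed`, `sample_latent` and the `use_mean` flag are hypothetical names for illustration, not the sfaira API.

```python
import numpy as np


def sample_latent(z_mean, z_log_var, rng):
    """Draw z ~ N(z_mean, exp(z_log_var)) with the reparameterization trick."""
    eps = rng.standard_normal(z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * eps


def embed(z_mean, z_log_var, use_mean=True, rng=None):
    """Hypothetical helper: deterministic embedding (z_mean) or a stochastic sample."""
    if use_mean:
        return z_mean  # the same input always yields the same coordinates
    rng = rng if rng is not None else np.random.default_rng()
    return sample_latent(z_mean, z_log_var, rng)


# Toy encoder outputs for 3 cells in a 2-dimensional latent space.
z_mean = np.array([[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]])
z_log_var = np.full_like(z_mean, -2.0)

print(embed(z_mean, z_log_var))  # returns z_mean unchanged
noisy = embed(z_mean, z_log_var, use_mean=False, rng=np.random.default_rng(0))
print(noisy.shape)  # (3, 2)
```

Using the mean gives a deterministic embedding, which is usually what downstream steps such as clustering or visualization expect.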