sfaira.estimators.EstimatorKerasEmbedding

class sfaira.estimators.EstimatorKerasEmbedding(data: typing.Union[anndata._core.anndata.AnnData, numpy.ndarray, sfaira.data.store.stores.single.StoreSingleFeatureSpace], model_dir: typing.Optional[str], model_id: typing.Optional[str], model_topology: sfaira.versions.topologies.class_interface.TopologyContainer, weights_md5: typing.Optional[str] = None, cache_path: str = 'cache/', adata_ids: sfaira.consts.adata_fields.AdataIds = <sfaira.consts.adata_fields.AdataIdsSfaira object>)

Estimator class for the embedding model.

Attributes

model_type

obs

obs_eval

obs_test

obs_train

organism

using_store

model

Methods

compute_gradients_input([batch_size, ...])

evaluate([batch_size, max_steps])

Evaluate the custom model on test data.

evaluate_any(idx[, batch_size, max_steps])

Evaluate the custom model on any local data.

get_one_time_tf_dataset(idx, mode[, ...])

init_model([clear_weight_cache, ...])

Instantiate the model.

load_pretrained_weights()

Load model weights from a local directory or from Zenodo.
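The constructor accepts an optional weights_md5 checksum. A plausible use is to check the integrity of the downloaded weights file, although this page does not say so. The sketch below assumes that use and shows an MD5 check with the standard library. The helper names file_md5 and check_weights are hypothetical and are not part of sfaira.

```python
import hashlib
import os
import tempfile
from typing import Optional


def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def check_weights(path: str, weights_md5: Optional[str]) -> bool:
    """Hypothetical helper: accept the file if no checksum is given, else compare digests."""
    if weights_md5 is None:
        return True
    return file_md5(path) == weights_md5


if __name__ == "__main__":
    # Write a stand-in "weights" file and verify it against its known checksum.
    with tempfile.NamedTemporaryFile(delete=False) as fh:
        fh.write(b"weights")
        weights_path = fh.name
    print(check_weights(weights_path, hashlib.md5(b"weights").hexdigest()))
    os.remove(weights_path)
```

A mismatching digest would indicate a corrupted or wrong file. Whether sfaira raises, retries, or warns in that case is not specified here.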

load_weights_from_cache()

predict([batch_size])

Return the prediction of the model.

predict_embedding([batch_size, variational])

Return the prediction in the latent space (z_mean for variational models).
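For a variational encoder that outputs a mean and a log-variance per latent dimension, a sampled latent vector is stochastic. The mean z_mean is deterministic and therefore repeatable, which is why it is a natural point embedding. The sketch below illustrates this distinction in plain Python. It assumes a Gaussian posterior sampled with the reparameterisation trick. The encoder outputs are made-up numbers, and the function names are hypothetical and do not belong to sfaira's API.

```python
import math
import random
from typing import List, Sequence


def sample_latent(z_mean: Sequence[float], z_log_var: Sequence[float],
                  rng: random.Random) -> List[float]:
    """Draw z = mu + sigma * eps with sigma = exp(0.5 * log_var) and eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(z_mean, z_log_var)]


def point_embedding(z_mean: Sequence[float]) -> List[float]:
    """Deterministic embedding: the posterior mean itself."""
    return list(z_mean)


# Hypothetical encoder outputs for one cell with a 3-dimensional latent space.
mu = [0.5, -1.0, 2.0]
log_var = [0.0, -2.0, 1.0]

rng = random.Random(0)
draw_a = sample_latent(mu, log_var, rng)
draw_b = sample_latent(mu, log_var, rng)
print(point_embedding(mu))  # identical on every call
print(draw_a != draw_b)     # successive samples differ
```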

save_weights_to_cache()

split_train_val_test(val_split, test_split)

Split indices in store into train, validation and test split.
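The sketch below shows one way to perform a random three-way index split. It assumes that val_split and test_split are fractions of all observations. This is a generic illustration and not sfaira's implementation, which may, for example, split along other groupings of the data. The function split_indices and its seed parameter are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def split_indices(idx: Sequence[int], val_split: float, test_split: float,
                  seed: int = 0) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle indices, cut off test and validation fractions; the remainder is train."""
    if val_split < 0.0 or test_split < 0.0 or val_split + test_split >= 1.0:
        raise ValueError("val_split and test_split must be non-negative and sum to < 1")
    shuffled = list(idx)
    random.Random(seed).shuffle(shuffled)  # seeded for a reproducible split
    n = len(shuffled)
    n_test = round(n * test_split)
    n_val = round(n * val_split)
    test = sorted(shuffled[:n_test])
    val = sorted(shuffled[n_test:n_test + n_val])
    train = sorted(shuffled[n_test + n_val:])
    return train, val, test


train, val, test = split_indices(range(100), val_split=0.1, test_split=0.2)
print(len(train), len(val), len(test))  # 70 10 20
```

The three index lists are disjoint and together cover every input index.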

train(optimizer, lr[, epochs, ...])

Train model.