sfaira.estimators.EstimatorKeras

class sfaira.estimators.EstimatorKeras

Estimator base class for keras models.

Important: a subclass that implements the abstract methods must also inherit from the EstimatorBase class.
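The inheritance requirement above is a multiple-inheritance (mixin) pattern: the keras-specific base supplies shared behaviour, while EstimatorBase supplies the abstract interface a concrete estimator must fill in. The sketch below illustrates only that pattern. `EstimatorBaseStandIn`, `EstimatorKerasStandIn`, `MyEstimator` and `Incomplete` are hypothetical stand-ins, not sfaira's real class definitions, and `init_model` is used as the abstract method purely as an assumption for illustration.

```python
import abc


class EstimatorBaseStandIn(abc.ABC):
    """Hypothetical stand-in for EstimatorBase: declares the abstract interface."""

    @abc.abstractmethod
    def init_model(self):
        ...


class EstimatorKerasStandIn:
    """Hypothetical stand-in for the keras base: provides shared helpers."""

    def describe(self):
        return f"{type(self).__name__} with model={self.model!r}"


class MyEstimator(EstimatorKerasStandIn, EstimatorBaseStandIn):
    """Concrete estimator inheriting the helpers AND the abstract interface."""

    def __init__(self):
        self.model = None

    def init_model(self):
        # Placeholder; a real subclass would build a keras model here.
        self.model = "dummy-model"


class Incomplete(EstimatorKerasStandIn, EstimatorBaseStandIn):
    """Does not implement init_model, so it cannot be instantiated."""


est = MyEstimator()
est.init_model()
print(est.describe())  # MyEstimator with model='dummy-model'
```

Because `EstimatorKerasStandIn` is listed first, its methods take precedence in the method resolution order. Inheriting from the ABC means Python refuses to instantiate a subclass (such as `Incomplete`) that leaves an abstract method unimplemented, raising `TypeError`.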

Attributes

data

model

weights

model_dir

history

train_hyperparam

idx_train

idx_eval

idx_test

cache_path

model_id

md5

Methods

get_one_time_tf_dataset(idx, mode[, ...])

init_model([clear_weight_cache, ...])

Instantiate the model.

load_pretrained_weights()

Load model weights from a local directory or from Zenodo.

load_weights_from_cache()

save_weights_to_cache()

split_train_val_test(val_split, test_split)

Split the indices in the store into train, validation and test sets.
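A random three-way index split of the kind described above can be sketched as follows. This is not sfaira's implementation: `split_indices` and its `seed` parameter are hypothetical, and treating `val_split` and `test_split` as fractions of all observations (rather than, say, of the remaining data) is an assumption. The class lists `idx_train`, `idx_eval` and `idx_test` attributes, which presumably hold the result of such a split.

```python
import numpy as np


def split_indices(n_obs, val_split, test_split, seed=0):
    """Hypothetical helper: randomly partition range(n_obs) into train/val/test.

    Assumes val_split and test_split are fractions of all observations.
    """
    if not 0 <= val_split + test_split < 1:
        raise ValueError("val_split + test_split must be in [0, 1)")
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_obs)  # shuffled copy of 0..n_obs-1
    n_test = int(round(test_split * n_obs))
    n_val = int(round(val_split * n_obs))
    # Slice the shuffled indices into three disjoint blocks, sorted for readability.
    idx_test = np.sort(idx[:n_test])
    idx_val = np.sort(idx[n_test:n_test + n_val])
    idx_train = np.sort(idx[n_test + n_val:])
    return idx_train, idx_val, idx_test


idx_train, idx_val, idx_test = split_indices(100, val_split=0.1, test_split=0.2)
print(len(idx_train), len(idx_val), len(idx_test))  # 70 10 20
```

Permuting once and slicing guarantees the three sets are disjoint and together cover every index; fixing the seed makes the split reproducible across runs.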

train(optimizer, lr[, epochs, ...])

Train the model.