sfaira.estimators.EstimatorKeras.train
- EstimatorKeras.train(optimizer: str, lr: float, epochs: int = 1000, max_steps_per_epoch: Optional[int] = 20, batch_size: int = 128, validation_split: float = 0.1, test_split: Union[float, dict] = 0.0, validation_batch_size: int = 256, max_validation_steps: Optional[int] = 10, patience: int = 20, lr_schedule_min_lr: float = 1e-05, lr_schedule_factor: float = 0.2, lr_schedule_patience: int = 5, shuffle_buffer_size: Optional[int] = None, cache_full: bool = False, randomized_batch_access: bool = True, retrieval_batch_size: int = 512, prefetch: Optional[int] = 1, log_dir: Optional[str] = None, callbacks: Optional[list] = None, weighted: bool = False, verbose: int = 2)
Train the model.
Training stops when the validation loss has not improved for patience epochs or when the maximum number of epochs is reached, whichever comes first.
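The termination logic described above can be pictured with a small pure-Python simulation. This is a simplified sketch, assuming that early stopping (patience) and the learning rate reduction schedule (lr_schedule_factor, lr_schedule_patience, lr_schedule_min_lr) both monitor the validation loss in the style of tf.keras.callbacks.EarlyStopping and tf.keras.callbacks.ReduceLROnPlateau. It ignores min_delta and cooldown, and the function simulate_training_schedule is hypothetical, not part of sfaira.

```python
from typing import List, Tuple


def simulate_training_schedule(
    val_losses: List[float],
    lr: float,
    epochs: int = 1000,
    patience: int = 20,
    lr_schedule_factor: float = 0.2,
    lr_schedule_patience: int = 5,
    lr_schedule_min_lr: float = 1e-05,
) -> Tuple[int, List[float]]:
    """Hypothetical helper: return the number of epochs run and the lr used per epoch.

    Simplified stand-in for EarlyStopping + ReduceLROnPlateau on validation loss
    (assumption; not sfaira code).
    """
    best_loss = float("inf")
    epochs_since_best = 0     # counter for early stopping
    epochs_since_lr_drop = 0  # counter for the learning rate schedule
    lrs: List[float] = []
    n_run = 0
    # Never run more than `epochs` epochs (second termination criterion).
    for epoch in range(min(epochs, len(val_losses))):
        lrs.append(lr)
        n_run += 1
        loss = val_losses[epoch]
        if loss < best_loss:
            # Improvement resets both counters.
            best_loss = loss
            epochs_since_best = 0
            epochs_since_lr_drop = 0
        else:
            epochs_since_best += 1
            epochs_since_lr_drop += 1
        # Plateau: multiply lr by the factor, but never go below the minimum.
        if epochs_since_lr_drop >= lr_schedule_patience:
            lr = max(lr * lr_schedule_factor, lr_schedule_min_lr)
            epochs_since_lr_drop = 0
        # Early stopping on validation loss (first termination criterion).
        if epochs_since_best >= patience:
            break
    return n_run, lrs
```

For example, with validation losses [1.0, 0.9, 0.8] followed by a long plateau at 0.8, lr=0.1 and the default patience values, the sketch runs 23 epochs and lowers the learning rate after every 5 stagnant epochs.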
- Parameters
optimizer – Name of the tf.keras optimizer to use for fitting.
lr – Learning rate.
epochs – Maximum number of epochs; refer to tf.keras.models.Model.fit() documentation.
max_steps_per_epoch – Maximum number of training steps (batches) per epoch.
batch_size – refer to tf.keras.models.Model.fit() documentation
validation_split – Fraction of the data remaining after the test split (i.e. of the fraction 1 - test_split of all observations) to use for validation; refer to tf.keras.models.Model.fit() documentation.
test_split – Fraction of the data to hold out as a test set before the train-validation split.
validation_batch_size – Number of validation observations per batch on which evaluation metrics are computed.
max_validation_steps – Maximum number of validation steps (batches) to perform.
patience – Number of epochs without improvement of the validation loss after which training is stopped early; refer to tf.keras.callbacks.EarlyStopping documentation.
lr_schedule_min_lr – Minimum learning rate of the learning rate reduction schedule.
lr_schedule_factor – Factor by which the learning rate is multiplied when a plateau is reached in the learning rate reduction schedule.
lr_schedule_patience – Number of epochs without improvement before the learning rate is reduced in the learning rate reduction schedule.
shuffle_buffer_size – buffer_size argument of tf.data.Dataset.shuffle().
cache_full – Whether to use TensorFlow caching on the full training and validation data.
randomized_batch_access – Whether to randomize batches during reading (in the generator). This removes the need for a shuffle buffer on the generator. However, batch composition then stays unchanged across epochs, except where retrieval_batch_size leaves overhangs in the raw data files; this happens often and results in modest changes in batch composition.
log_dir – Directory to which the TensorBoard callback writes logs. Disabled if None.
callbacks – Additional callbacks to add to the training call.
weighted –
verbose – Verbosity mode; refer to tf.keras.models.Model.fit() documentation.
- Returns
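The interaction of test_split and validation_split, and of the batch sizes with the step caps, can be illustrated with plain arithmetic. The following is a minimal sketch, assuming that the test set is removed first, that validation_split applies to the remaining observations as described for validation_split above, that counts are rounded down, and that the number of steps per epoch is the number of batches needed to cover a split, capped at max_steps_per_epoch or max_validation_steps. The helpers split_sizes and steps_per_epoch are hypothetical and not part of sfaira.

```python
import math
from typing import Optional, Tuple


def split_sizes(
    n_obs: int, test_split: float = 0.0, validation_split: float = 0.1
) -> Tuple[int, int, int]:
    """Hypothetical helper: (n_train, n_val, n_test) observation counts.

    The test set is carved out first; validation_split then applies to the
    remaining (1 - test_split) fraction. Rounding down is an assumption.
    """
    n_test = int(n_obs * test_split)
    n_val = int((n_obs - n_test) * validation_split)
    n_train = n_obs - n_test - n_val
    return n_train, n_val, n_test


def steps_per_epoch(n: int, batch_size: int, max_steps: Optional[int]) -> int:
    """Hypothetical helper: batches needed to cover n observations, capped by max_steps."""
    steps = math.ceil(n / batch_size)
    return steps if max_steps is None else min(steps, max_steps)
```

For instance, 1,000 observations with test_split=0.2 and validation_split=0.25 give 600 training, 200 validation and 200 test observations; with batch_size=128 the training split needs 5 batches, below the default cap of 20 steps per epoch, and with validation_batch_size=256 a single validation step covers the validation split.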