sfaira.estimators.EstimatorKeras.train

EstimatorKeras.train(optimizer: str, lr: float, epochs: int = 1000, max_steps_per_epoch: Optional[int] = 20, batch_size: int = 128, validation_split: float = 0.1, test_split: Union[float, dict] = 0.0, validation_batch_size: int = 256, max_validation_steps: Optional[int] = 10, patience: int = 20, lr_schedule_min_lr: float = 1e-05, lr_schedule_factor: float = 0.2, lr_schedule_patience: int = 5, shuffle_buffer_size: Optional[int] = None, cache_full: bool = False, randomized_batch_access: bool = True, retrieval_batch_size: int = 512, prefetch: Optional[int] = 1, log_dir: Optional[str] = None, callbacks: Optional[list] = None, weighted: bool = False, verbose: int = 2)

Train model.

Uses validation loss and maximum number of epochs as termination criteria.

Parameters
  • optimizer – str corresponding to tf.keras optimizer to use for fitting.

  • lr – Learning rate

  • epochs – refer to tf.keras.models.Model.fit() documentation

  • max_steps_per_epoch – Maximum steps per epoch.

  • batch_size – refer to tf.keras.models.Model.fit() documentation

  • validation_split – refer to tf.keras.models.Model.fit() documentation. Refers to the fraction of the training data (i.e. of 1 - test_split) that is used for validation.

  • test_split – Fraction of data to set apart for testing the model, applied before the train-validation split.

  • validation_batch_size – Batch size used when computing evaluation metrics on the validation data.

  • max_validation_steps – Maximum number of validation steps to perform.

  • patience – Patience of early stopping on validation loss; refer to tf.keras.callbacks.EarlyStopping documentation.

  • lr_schedule_min_lr – Minimum learning rate for learning rate reduction schedule.

  • lr_schedule_factor – Factor to reduce learning rate by within learning rate reduction schedule when plateau is reached.

  • lr_schedule_patience – Patience for learning rate reduction in learning rate reduction schedule.

  • shuffle_buffer_size – tf.data.Dataset.shuffle(): buffer_size argument.

  • cache_full – Whether to use tensorflow caching on full training and validation data.

  • randomized_batch_access – Whether to randomize batches during reading (in the generator). This lifts the necessity of using a shuffle buffer on the generator; however, batch composition stays unchanged over epochs unless there are overhangs of retrieval_batch_size in the raw data files, which often happens and results in modest changes in batch composition.

  • retrieval_batch_size – Number of observations read from the raw data files per read operation (in the generator).

  • prefetch – tf.data.Dataset.prefetch(): buffer_size argument.

  • log_dir – Directory to write TensorBoard logs to. TensorBoard logging is disabled if None.

  • callbacks – Additional callbacks to add to the training call.

  • weighted

  • verbose – refer to tf.keras.models.Model.fit() documentation
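
The interaction of test_split and validation_split can be illustrated with a short sketch (plain Python, not sfaira code; the function name and counts are illustrative only): test_split is taken from the full data first, and validation_split is then a fraction of the remaining (1 - test_split) observations.

```python
def split_counts(n_obs: int, test_split: float, validation_split: float):
    """Illustrate how test_split is applied before the train-validation split."""
    n_test = int(n_obs * test_split)       # held out first
    n_remaining = n_obs - n_test           # pool for the train-validation split
    n_val = int(n_remaining * validation_split)
    n_train = n_remaining - n_val
    return n_train, n_val, n_test

# With 1000 observations, test_split=0.1 and validation_split=0.1:
# 100 test, 90 validation (10% of the remaining 900), 810 train.
print(split_counts(1000, 0.1, 0.1))  # (810, 90, 100)
```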

Returns
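
The termination and schedule parameters (patience, lr_schedule_factor, lr_schedule_patience, lr_schedule_min_lr) follow the semantics of tf.keras's EarlyStopping and ReduceLROnPlateau callbacks, to which they are presumably forwarded. A minimal plain-Python sketch of the plateau logic (not sfaira's implementation), using this method's defaults:

```python
def reduce_lr_on_plateau(val_losses, lr, factor=0.2, patience=5, min_lr=1e-5):
    """Reduce lr by `factor` whenever the validation loss fails to improve
    for `patience` consecutive epochs, never going below `min_lr`."""
    best = float("inf")
    wait = 0
    for loss in val_losses:
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:  # plateau reached: shrink the learning rate
                lr = max(lr * factor, min_lr)
                wait = 0
    return lr

# Five consecutive non-improving epochs trigger one reduction: 0.001 * 0.2 = 0.0002.
print(reduce_lr_on_plateau([1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9], lr=1e-3))
```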