sfaira.train.SummarizeGridsearchEmbedding.plot_training_history

SummarizeGridsearchEmbedding.plot_training_history(metric_select: str, metric_show: str, partition_select: str = 'val', subset: dict = {}, cv_key: Optional[str] = None, log_loss: bool = False)

Plot the train and validation loss and the learning-rate reduction during training for each organ.

By default, the train and validation partitions are shown, because these are the only ones recorded during training.
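
The per-epoch series behind such a plot can be sketched as follows. This is a minimal, dependency-free sketch under stated assumptions, not sfaira's implementation: it assumes a Keras-style history dict with training values under "<name>", validation values under "val_<name>" and the learning rate under "lr". The helper name training_curves is hypothetical, and treating log_loss as "plot the natural log of the loss" is also an assumption.

```python
import math
from typing import Dict, List


def training_curves(history: Dict[str, List[float]], metric_show: str,
                    log_loss: bool = False) -> Dict[str, List[float]]:
    """Collect the per-epoch series a training-history plot would draw.

    Hypothetical input layout (Keras-style): train values under "<name>",
    validation values under "val_<name>", learning rate under "lr".
    """
    n_epochs = len(history["loss"])
    curves: Dict[str, List[float]] = {"epoch": list(range(1, n_epochs + 1))}
    for prefix, label in (("", "train"), ("val_", "val")):
        loss = history[prefix + "loss"]
        # Assumption: log_loss means showing the natural log of the loss.
        curves[f"{label}_loss"] = [math.log(v) for v in loss] if log_loss else list(loss)
        curves[f"{label}_{metric_show}"] = list(history[prefix + metric_show])
    # Learning rate recorded once per epoch; reductions appear as drops.
    curves["lr"] = list(history["lr"])
    return curves


history = {"loss": [1.0, 0.5], "val_loss": [1.2, 0.7],
           "mae": [0.9, 0.4], "val_mae": [1.0, 0.6], "lr": [1e-3, 1e-4]}
curves = training_curves(history, "mae")
```

Each returned series can then be drawn against curves["epoch"], for example with the loss curves and metric_show on one axis and the learning rate on another.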

Parameters
  • metric_select – Metric to select the best model by.

  • metric_show – Metric to show as a function of training progress, together with the loss and learning rate.

  • partition_select – “train”, “val” or “test” partition of the data to select the fit by.

  • cv_key – Index of the cross-validation fold to plot the training history for.

  • log_loss
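
How metric_select, partition_select and cv_key interact when picking the run to plot can be illustrated with a small sketch. It is an assumption-laden illustration, not sfaira's code: the results layout results[(model_id, cv_key)][partition][metric], the helper name select_best_run and the lower_is_better flag are all hypothetical.

```python
from typing import Dict, Optional, Tuple

# Hypothetical layout: results[(model_id, cv_key)][partition][metric] -> value
Results = Dict[Tuple[str, str], Dict[str, Dict[str, float]]]


def select_best_run(results: Results, metric_select: str,
                    partition_select: str = "val",
                    cv_key: Optional[str] = None,
                    lower_is_better: bool = True) -> Tuple[str, str]:
    """Return the (model_id, cv_key) with the best metric on a partition."""
    # Keep only runs from the requested cross-validation fold, if one is given.
    candidates = {
        key: parts[partition_select][metric_select]
        for key, parts in results.items()
        if cv_key is None or key[1] == cv_key
    }
    if not candidates:
        raise ValueError("no runs match the given cv_key")
    # Losses are minimised; score-like metrics would use lower_is_better=False.
    pick = min if lower_is_better else max
    return pick(candidates, key=candidates.get)


results = {
    ("a", "0"): {"train": {"loss": 0.3}, "val": {"loss": 0.5}},
    ("b", "0"): {"train": {"loss": 0.6}, "val": {"loss": 0.4}},
    ("a", "1"): {"train": {"loss": 0.1}, "val": {"loss": 0.2}},
}
best = select_best_run(results, "loss", cv_key="0")
```

Here the best validation loss within fold "0" belongs to model "b", so its training history would be the one plotted.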

Returns