sfaira.train.SummarizeGridsearchCelltype
class sfaira.train.SummarizeGridsearchCelltype(source_path: dict, cv: bool, model_id_len: int = 3)
Attributes
Returns the cross-validation keys used in the dictionaries of this class.
Methods
best_model_by_partition(partition_select, ...)
best_model_celltype([subset, partition, ...])
get_best_model_ids(tab, metric_select, ...)
load_gs(gs_ids): Loads all relevant data of a grid search.
load_ontology_names(run_id): Loads ontology IDs from a specific model of a previously loaded grid search.
load_y(hat_or_true, run_id)
plot_best([rename_levels, partition_select, ...]): Plot a heatmap of accuracy or another metric by organ and model type.
plot_best_classwise_heatmap(organ, organism, ...): Plot a heatmap of an evaluation metric for the specified organ by cell class and model type.
plot_best_classwise_scatter(organ, organism, ...): Plot a scatter plot of an evaluation metric for the specified organ by cell class and model type.
plot_best_model_by_hyperparam(metric_select): Produces boxplots of all hyperparameter choices by organ.
plot_completions([groupby, height_fig, ...]): Plot the number of completed grid search points by category.
plot_training_history(metric_select, metric_show): Plot training and validation loss during training, and learning rate reductions, for each organ.
save_best_weight(path[, partition, metric, ...]): Copies the weight file of the best hyperparameter setting from the grid search directory to the zoo directory under a cleaned file name.
write_best_hyparam(write_path[, subset, ...])
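Several of the methods listed above revolve around two operations: choosing the best hyperparameter setting for a selected partition and metric (best_model_by_partition, get_best_model_ids, save_best_weight) and counting completed grid search points by category (plot_completions). The sketch below illustrates that selection logic in plain Python. It does not call sfaira and is not its implementation: the record type GridsearchRun, its fields, the functions best_model_ids and count_completions, and the higher_is_better flag are all hypothetical, and the run values are made up.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Set


@dataclass(frozen=True)
class GridsearchRun:
    """One evaluated grid search point (hypothetical record layout)."""
    model_id: str   # identifier of the hyperparameter setting
    organ: str      # grouping category, e.g. the organ a model was trained on
    partition: str  # data partition the metric was computed on, e.g. "val"
    metric: str     # metric name, e.g. "loss" or "accuracy"
    value: float    # metric value


def is_better(new: float, old: float, higher_is_better: bool) -> bool:
    """Compare two metric values; ties are not considered an improvement."""
    return new > old if higher_is_better else new < old


def best_model_ids(
    runs: Iterable[GridsearchRun],
    partition_select: str,
    metric_select: str,
    higher_is_better: bool,
) -> Dict[str, str]:
    """Return the best model_id per organ on one partition and metric."""
    best: Dict[str, GridsearchRun] = {}
    for run in runs:
        # Only compare runs evaluated on the selected partition and metric.
        if run.partition != partition_select or run.metric != metric_select:
            continue
        current = best.get(run.organ)
        if current is None or is_better(run.value, current.value, higher_is_better):
            best[run.organ] = run
    return {organ: run.model_id for organ, run in best.items()}


def count_completions(runs: Iterable[GridsearchRun]) -> Dict[str, int]:
    """Count distinct completed hyperparameter settings per organ."""
    seen: Dict[str, Set[str]] = {}
    for run in runs:
        seen.setdefault(run.organ, set()).add(run.model_id)
    return {organ: len(ids) for organ, ids in seen.items()}


# Made-up example runs for illustration only.
runs = [
    GridsearchRun("mlp_1", "lung", "val", "loss", 0.42),
    GridsearchRun("mlp_2", "lung", "val", "loss", 0.35),
    GridsearchRun("mlp_2", "lung", "train", "loss", 0.10),
    GridsearchRun("mlp_1", "liver", "val", "loss", 0.50),
]
print(best_model_ids(runs, "val", "loss", higher_is_better=False))
# {'lung': 'mlp_2', 'liver': 'mlp_1'}
print(count_completions(runs))
# {'lung': 2, 'liver': 1}
```

Filtering on the partition before comparing matters: here the training-set loss of mlp_2 is much lower than any validation loss, and selecting on it would reward overfitting rather than generalization.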