sfaira.train.SummarizeGridsearchEmbedding

class sfaira.train.SummarizeGridsearchEmbedding(source_path: dict, cv: bool, loss_idx: int = 0, mse_idx: int = 1, model_id_len: int = 3)

Attributes

cv_keys

Returns the cross-validation keys used in the dictionaries of this class.

loss_idx

mse_idx

Methods

best_model_by_partition(partition_select, ...)

best_model_embedding([subset, partition, ...])

create_summary_tab()

get_best_model_ids(tab, metric_select, ...)

get_gradients_by_celltype(model_organ, ...)

Compute gradients across latent units with respect to input features for each cell type.
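As a minimal sketch of the idea, assume a hypothetical one-layer encoder z = tanh(xW) in place of sfaira's actual TensorFlow models. For this encoder, the gradient of latent unit j with respect to input feature i is (1 - z_j^2) W_ij. The sketch computes this per cell, sums absolute values across latent units, and averages within each cell type. The function name `gradients_by_celltype`, the encoder, and the absolute-value accumulation are illustrative assumptions, not sfaira's implementation.

```python
import numpy as np


def gradients_by_celltype(x, w, celltypes):
    """Mean accumulated |dz/dx| per cell type for a hypothetical encoder z = tanh(x @ w).

    x: (n_cells, n_features) input matrix
    w: (n_features, n_latent) encoder weights
    celltypes: length-n_cells sequence of cell type labels
    Returns a dict mapping cell type -> (n_features,) gradient vector.
    """
    z = np.tanh(x @ w)  # (n_cells, n_latent)
    # Per-cell Jacobian: dz_j/dx_i = (1 - z_j**2) * w_ij -> (n_cells, n_features, n_latent)
    jac = (1.0 - z**2)[:, None, :] * w[None, :, :]
    # Accumulate absolute gradients across latent units for each input feature.
    grads = np.abs(jac).sum(axis=2)  # (n_cells, n_features)
    celltypes = np.asarray(celltypes)
    return {ct: grads[celltypes == ct].mean(axis=0) for ct in np.unique(celltypes)}
```

With a real network the Jacobian would come from automatic differentiation (e.g. a TensorFlow gradient tape) rather than a closed form.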

load_gs(gs_ids)

Loads all relevant data of a grid search.

load_y(hat_or_true, run_id)

plot_active_latent_units(organ, topology_version)

Plots latent unit activity, measured as the empirical variance of the expected latent space.
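A minimal sketch of this activity measure, assuming the expected latent space is available as a cells × latent-units NumPy array (for a VAE, for example, the posterior means). The 0.01 cutoff for calling a unit active is a commonly used heuristic adopted here as an assumption; the function name is hypothetical.

```python
import numpy as np


def latent_unit_activity(z_mean, threshold=0.01):
    """Empirical variance of each latent unit across cells.

    z_mean: (n_cells, n_latent) expected latent coordinates.
    threshold: assumed cutoff above which a unit counts as active.
    Returns (per-unit variances, boolean mask of active units).
    """
    variances = np.var(z_mean, axis=0)  # one variance per latent unit
    return variances, variances > threshold
```

A unit whose expected value is the same for every cell has zero variance and is therefore reported as inactive, regardless of the value it takes.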

plot_best([rename_levels, partition_select, ...])

plot_best_model_by_hyperparam(metric_select)

Produces boxplots for all hyperparameter choices by organ.

plot_completions([groupby, height_fig, ...])

Plot the number of completed grid search points by category.

plot_gradient_cor(model_organ, data_organ, ...)

Plot a correlation heatmap, across cell types or models, of the gradient vectors accumulated on input features.
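The correlation matrix behind such a heatmap can be sketched with `np.corrcoef`, assuming one accumulated gradient vector (over input features) per cell type or model. The helper name and the dict-of-vectors input layout are assumptions, and the plotting step itself is left out.

```python
import numpy as np


def gradient_correlation(grad_vectors):
    """Pearson correlation between gradient vectors.

    grad_vectors: dict mapping a label (cell type or model) to a
    (n_features,) gradient vector accumulated on input features.
    Returns (labels, (n_labels, n_labels) correlation matrix).
    """
    labels = list(grad_vectors)
    mat = np.stack([np.asarray(grad_vectors[k], dtype=float) for k in labels])
    return labels, np.corrcoef(mat)  # each row is treated as one variable
```

The returned matrix is symmetric with ones on the diagonal, so it can be passed directly to a heatmap routine with `labels` as tick labels.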

plot_gradient_distr(model_organ, data_organ, ...)

plot_npc(organ, topology_version[, cvs])

Plots the cumulative explained variance ratio of the latent space’s ordered principal components.
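A minimal NumPy sketch of the plotted quantity, assuming the latent embedding is a cells × latent-dimensions array. The sketch centers the embedding, takes the singular values of the centered matrix, and accumulates the normalized squared singular values. The function name is hypothetical, and sfaira may compute the PCA differently (for example via scikit-learn).

```python
import numpy as np


def cumulative_explained_variance(z):
    """Cumulative explained variance ratio of ordered principal components.

    z: (n_cells, n_latent) latent embedding.
    """
    zc = z - z.mean(axis=0)  # center each latent dimension
    s = np.linalg.svd(zc, compute_uv=False)  # singular values, descending
    var = s**2  # proportional to the variance explained by each PC
    return np.cumsum(var) / var.sum()
```

The curve is non-decreasing and ends at 1; the number of components needed to reach a chosen level (e.g. 0.9) is one way to summarize how many latent dimensions are effectively used.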

plot_training_history(metric_select, metric_show)

Plot training and validation loss over the course of training, together with learning rate reductions, for each organ.
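Assuming the training history is stored as a dict of per-epoch lists, as in Keras' `History.history`, the learning rate reductions can be located as the epochs at which the logged rate drops. The `lr` key is an assumption: whether and under which name the rate is logged depends on the callbacks and TensorFlow version used. The helper name is hypothetical.

```python
def lr_reduction_epochs(history, lr_key="lr"):
    """Epoch indices at which the learning rate dropped.

    history: dict of per-epoch lists, e.g. {"loss": [...], "val_loss": [...], "lr": [...]}.
    lr_key: assumed key under which the learning rate was logged.
    """
    lr = history[lr_key]
    return [i for i in range(1, len(lr)) if lr[i] < lr[i - 1]]
```

The returned epochs can then be drawn as vertical markers on the `loss` and `val_loss` curves.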

save_best_weight(path[, partition, metric, ...])

Copies the weight file of the best hyperparameter setting from the grid search directory to the zoo directory under a cleaned file name.
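The copy step can be sketched with the standard library, assuming the path of the best run's weight file is already known. The cleaning rule below (replacing characters outside `[A-Za-z0-9._-]` with `_`) is a hypothetical stand-in, because sfaira's actual naming scheme is not documented here. The helper name `copy_best_weights` is also hypothetical.

```python
import re
import shutil
from pathlib import Path


def copy_best_weights(src, zoo_dir, new_name):
    """Copy a weight file into a zoo directory under a cleaned name.

    The cleaning rule (replace anything outside [A-Za-z0-9._-] with "_")
    is a hypothetical stand-in for sfaira's own naming scheme.
    """
    clean = re.sub(r"[^A-Za-z0-9._-]", "_", new_name)
    zoo = Path(zoo_dir)
    zoo.mkdir(parents=True, exist_ok=True)  # create the zoo directory if missing
    dst = zoo / clean
    shutil.copyfile(src, dst)
    return dst
```

Using `shutil.copyfile` copies only the file contents, leaving the grid search directory untouched.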

write_best_hyparam(write_path[, subset, ...])