You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mamlr/R/cv_generator.R

63 lines
3.1 KiB

#' Generate CV folds for nested cross-validation
#'
#' Creates a grid of models to be estimated for each outer fold, inner fold
#' and parameter combination.
#'
#' The random number generator is seeded with \code{seed} (kinds
#' "Mersenne-Twister" / "Inversion") before the outer folds are drawn, and
#' re-seeded with the same value before the inner folds of every outer fold
#' are drawn. This changes the global RNG state of the session.
#'
#' @param outer_k Number of outer CV (performance estimation) folds:
#'   \itemize{
#'     \item \code{NULL}: no outer loop; all data is split into
#'       \code{inner_k} folds for parameter optimization only.
#'     \item a number between 0 and 1 (exclusive): holdout sampling, with
#'       \code{outer_k} being the proportion of data held out as test set.
#'     \item a number >= 2: \code{outer_k}-fold cross-validation.
#'   }
#' @param inner_k Number of inner CV (parameter optimization) folds
#' @param vec Vector containing the true values of the classification; folds
#'   are stratified on \code{as.factor(vec)}
#' @param grid Parameter grid for optimization; crossed with the fold names
#'   via \code{crossing()}
#' @param seed integer used as seed for random number generation
#' @return A list with elements
#'   \itemize{
#'     \item \code{grid}: \code{grid} crossed with an \code{inner_fold}
#'       column and, unless \code{outer_k} is \code{NULL}, an
#'       \code{outer_fold} column.
#'     \item \code{outer_folds} (absent when \code{outer_k} is \code{NULL}):
#'       named list of row numbers of \code{vec} held out as test set for
#'       each outer fold.
#'     \item \code{inner_folds}: when \code{outer_k} is \code{NULL}, a named
#'       list of held-out row numbers of \code{vec} per inner fold; otherwise
#'       a list named by outer fold, each element being such a list whose row
#'       numbers index the outer training subset
#'       \code{vec[-outer_folds[[i]]]}, not \code{vec} itself.
#'   }
#' @export
#' @examples
#' labels <- rep(c("yes", "no"), times = 20)
#' param_grid <- data.frame(C = c(1, 10))
#' folds <- cv_generator(outer_k = 5, inner_k = 3, vec = labels,
#'                       grid = param_grid, seed = 42)
#################################################################################################
#################################### Generate CV folds ##########################################
#################################################################################################
cv_generator <- function(outer_k, inner_k, vec, grid, seed) {
  # Fail early with a clear message instead of deep inside caret. A single
  # outer fold (1 <= outer_k < 2) would leave no training rows for the inner
  # loop, and non-positive / missing values have no meaning here.
  if (!is.null(outer_k)) {
    valid_outer_k <- is.numeric(outer_k) && length(outer_k) == 1L &&
      !is.na(outer_k) && outer_k > 0 && (outer_k < 1 || outer_k >= 2)
    if (!valid_outer_k) {
      stop("`outer_k` must be NULL, a proportion between 0 and 1, ",
           "or a number of folds >= 2", call. = FALSE)
    }
  }

  # Draws the inner folds for outer fold `i` and crosses them with the grid.
  # The RNG is reset for every outer fold so each fold's inner split is
  # reproducible independently of how many folds were processed before it.
  inner_loop <- function(i, folds, vec, inner_k, grid, seed) {
    set.seed(seed, kind = "Mersenne-Twister", normal.kind = "Inversion")
    outer_fold <- names(folds)[i]
    # Row numbers returned here index vec[-folds[[i]]], not vec itself.
    inner_folds <- createFolds(as.factor(vec[-folds[[i]]]), k = inner_k)
    grid <- crossing(grid, inner_fold = names(inner_folds),
                     outer_fold = outer_fold)
    list(grid = grid, inner_folds = inner_folds, outer_fold = outer_fold)
  }

  set.seed(seed, kind = "Mersenne-Twister", normal.kind = "Inversion")

  if (is.null(outer_k)) {
    # No outer loop: use all data to generate inner_k folds for parameter
    # optimization only.
    inner_folds <- createFolds(as.factor(vec), k = inner_k)
    grid <- crossing(grid, inner_fold = names(inner_folds))
    return(list(grid = grid, inner_folds = inner_folds))
  }

  folds <- if (outer_k < 1) {
    # Holdout validation: one stratified sample of about outer_k of the rows
    # is used as the held-out test set (same role as a createFolds fold).
    createDataPartition(as.factor(vec), p = outer_k)
  } else {
    # Full nested CV.
    createFolds(as.factor(vec), k = outer_k)
  }

  # Generate the hyperparameter grid and inner-fold row numbers per outer fold.
  grid_folds <- lapply(seq_along(folds), inner_loop,
                       folds = folds,
                       vec = vec,
                       inner_k = inner_k,
                       grid = grid,
                       seed = seed)

  grid <- grid_folds %>% purrr::map("grid") %>% dplyr::bind_rows()
  inner_folds <- grid_folds %>% purrr::map("inner_folds")
  names(inner_folds) <- grid_folds %>% purrr::map_chr("outer_fold")

  list(grid = grid, outer_folds = folds, inner_folds = inner_folds)
}