You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mamlr/R/cv_generator.R

65 lines
3.4 KiB

#################################################################################################
#################################### Generate CV folds ##########################################
#################################################################################################
#' Generate CV folds for nested cross-validation
#'
#' Creates a grid of models to be estimated for each outer fold, inner fold and
#' parameter combination. All fold assignments are stratified on the class
#' labels found in \code{docvars(dfm, class_type)}.
#'
#' @param outer_k Number of outer CV (performance estimation) folds. If
#'   \code{outer_k < 1}, holdout sampling is used instead, with \code{outer_k}
#'   being the proportion of documents assigned to the test set. If
#'   \code{NULL}, no outer split is made and all documents are used to generate
#'   \code{inner_k} folds for parameter optimization only.
#' @param inner_k Number of inner CV (parameter optimization) folds
#' @param dfm DFM containing the labeled documents
#' @param class_type Name of the column in docvars containing the classification
#' @param grid Parameter grid for optimization (one row per parameter
#'   combination)
#' @param seed Integer used as seed for random number generation. The RNG is
#'   re-seeded (Mersenne-Twister / Inversion) before every fold split so each
#'   split is reproducible on its own. Note that this changes the global RNG
#'   state and kind.
#' @return A list with elements
#'   \describe{
#'     \item{grid}{\code{grid} crossed with an \code{inner_fold} column and,
#'       unless \code{outer_k} is \code{NULL}, an \code{outer_fold} column}
#'     \item{outer_folds}{(absent when \code{outer_k} is \code{NULL}) list of
#'       test-set row numbers in \code{dfm}, one element per outer fold}
#'     \item{inner_folds}{when \code{outer_k} is \code{NULL}, a list of
#'       held-out row numbers in \code{dfm}; otherwise a list named by outer
#'       fold whose elements are lists of held-out row numbers relative to that
#'       outer fold's training subset (\code{dfm} with the outer test rows
#'       removed), not to \code{dfm} itself}
#'   }
#' @export
#' @examples
#' \dontrun{
#' grid <- expand.grid(C = c(0.1, 1, 10))
#' cv_generator(outer_k = 5, inner_k = 3, dfm = dfm, class_type = "label",
#'              grid = grid, seed = 42)
#' }
cv_generator <- function(outer_k, inner_k, dfm, class_type, grid, seed) {
  if (!is.null(outer_k) &&
      !(is.numeric(outer_k) && length(outer_k) == 1L &&
        !is.na(outer_k) && outer_k > 0)) {
    stop("`outer_k` must be NULL or a single positive number.", call. = FALSE)
  }

  # Extract the labels once. Every split below is stratified on them, and the
  # caret splitters return positions into the vector they are given.
  labels <- as.factor(docvars(dfm, class_type))

  # RNG is re-seeded before every split so that each fold's partition is
  # reproducible independently of how many splits were drawn before it.
  reseed <- function() {
    set.seed(seed, kind = "Mersenne-Twister", normal.kind = "Inversion")
  }

  reseed()
  if (is.null(outer_k)) {
    # No outer split: all data is used for parameter optimization only.
    inner_folds <- createFolds(labels, k = inner_k)
    return(list(grid = crossing(grid, inner_fold = names(inner_folds)),
                inner_folds = inner_folds))
  }

  outer_folds <- if (outer_k < 1) {
    # Holdout: the single partition of proportion outer_k is the test set.
    createDataPartition(labels, p = outer_k)
  } else {
    # Full nested CV: each fold is in turn the held-out test set.
    createFolds(labels, k = outer_k)
  }

  all_rows <- seq_along(labels)
  per_fold <- lapply(seq_along(outer_folds), function(i) {
    reseed()
    # setdiff() instead of negative indexing: x[-integer(0)] selects nothing,
    # which would silently empty the training set if a fold were ever empty.
    train_rows <- setdiff(all_rows, outer_folds[[i]])
    # Inner fold row numbers are positions within train_rows, not within dfm.
    inner_folds <- createFolds(labels[train_rows], k = inner_k)
    list(
      grid = crossing(grid,
                      inner_fold = names(inner_folds),
                      outer_fold = names(outer_folds)[i]),
      inner_folds = inner_folds
    )
  })

  inner_folds <- lapply(per_fold, `[[`, "inner_folds")
  names(inner_folds) <- names(outer_folds)

  list(
    grid = dplyr::bind_rows(lapply(per_fold, `[[`, "grid")),
    outer_folds = outer_folds,
    inner_folds = inner_folds
  )
}