Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export(get_grouping)
export(get_metrics)
export(get_pairwise_comparisons)
export(get_pit_histogram)
export(get_unique_values)
export(interval_coverage)
export(is_forecast)
export(is_forecast_binary)
Expand Down
69 changes: 69 additions & 0 deletions R/get-unique-values.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#' @title Get unique value counts per forecast unit column
#'
#' @description
#' Given a forecast object, this function counts the number of unique values
#' in each column that defines the forecast unit (as determined by
#' [get_forecast_unit()]). This is useful for getting a quick overview of the
#' structure and scope of a forecast dataset.
#'
#' @param by character vector (default: `NULL`) with the names of one or more
#'   forecast unit columns. If specified, results are grouped by these
#'   column(s), showing unique value counts for each remaining forecast unit
#'   column per group. The grouping column(s) themselves are excluded from
#'   the column-level counts.
#'
#' @returns A data.table with columns `"column"` (the name of the forecast unit
#'   column) and `"N_unique"` (the number of unique values in that column).
#'   If `by` is specified, the grouping column(s) come first, followed by
#'   `"column"` and `"N_unique"`. If `by` names every forecast unit column,
#'   there is nothing left to count and a zero-row data.table with the same
#'   columns is returned.
#'
#' @inheritParams score
#' @importFrom checkmate assert_subset
#' @export
#' @keywords gain-insights
#' @examples
#' \dontshow{
#'   data.table::setDTthreads(2) # restricts number of cores used on CRAN
#' }
#'
#' example_quantile |>
#'   as_forecast_quantile() |>
#'   get_unique_values()
#'
#' example_quantile |>
#'   as_forecast_quantile() |>
#'   get_unique_values(by = "model")
get_unique_values <- function(forecast, by = NULL) {
  # Work on a cleaned copy (`copy = TRUE`, `na.omit = TRUE`) so the caller's
  # object is left untouched.
  forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
  forecast_unit <- get_forecast_unit(forecast)

  # Grouping is only allowed on columns that are part of the forecast unit.
  if (!is.null(by)) {
    assert_subset(by, forecast_unit, empty.ok = FALSE)
  }

  forecast <- as.data.table(forecast)

  # With `by = NULL`, setdiff() returns the forecast unit unchanged.
  cols <- setdiff(forecast_unit, by)

  if (is.null(by)) {
    out <- data.table(
      column = cols,
      N_unique = vapply(
        cols,
        function(col) length(unique(forecast[[col]])),
        integer(1),
        USE.NAMES = FALSE
      )
    )
  } else if (length(cols) == 0L) {
    # Every forecast unit column is used for grouping: return a typed,
    # zero-row result instead of running a grouped query over no columns.
    out <- data.table(
      forecast[0L, by, with = FALSE],
      column = character(0),
      N_unique = integer(0)
    )
  } else {
    # One row per (group, column). `names(.SD)` is used rather than the local
    # `cols` so that a data column called "cols" cannot shadow it inside `j`.
    out <- forecast[,
      list(
        column = names(.SD),
        N_unique = vapply(
          .SD,
          function(x) length(unique(x)),
          integer(1),
          USE.NAMES = FALSE
        )
      ),
      by = by,
      .SDcols = cols
    ]
  }

  return(out[])
}
1 change: 1 addition & 0 deletions R/z-globalVariables.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ globalVariables(c(
"metrics",
"metrics_select",
"model",
"N_unique",
"n_obs",
"n_obs wis_component_name",
"observed",
Expand Down
43 changes: 43 additions & 0 deletions man/get_unique_values.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

125 changes: 125 additions & 0 deletions tests/testthat/test-get-unique-values.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# ==============================================================================
# `get_unique_values()` # nolint: commented_code_linter
# ==============================================================================
test_that("get_unique_values() works with a quantile forecast", {
  forecast <- suppressMessages(as_forecast_quantile(example_quantile))
  result <- get_unique_values(forecast)

  expect_s3_class(result, c("data.table", "data.frame"), exact = TRUE)

  # Without `by`, the result has exactly the two documented columns
  expect_named(result, c("column", "N_unique"))
  expect_identical(ncol(result), 2L)

  # The entries of `column` should be exactly the forecast unit columns
  fu <- get_forecast_unit(forecast)
  expect_setequal(result[["column"]], fu)

  # Look up the unique-value count reported for one forecast unit column
  count_for <- function(unit_col) {
    result[result[["column"]] == unit_col, ][["N_unique"]]
  }

  # Check some expected unique value counts (after clean_forecast with na.omit)
  expect_identical(count_for("location"), 4L)
  expect_identical(count_for("target_type"), 2L)
  expect_identical(count_for("location_name"), 4L)
})

test_that("get_unique_values() works with different forecast types", {
  # Binary forecasts
  binary_result <- get_unique_values(
    suppressMessages(as_forecast_binary(example_binary))
  )
  expect_s3_class(binary_result, c("data.table", "data.frame"), exact = TRUE)

  # Only forecast unit columns may be reported, never protected columns
  binary_reported <- binary_result[[names(binary_result)[1]]]
  for (protected in c("predicted", "observed")) {
    expect_false(protected %in% binary_reported)
  }

  # Sample-based forecasts
  sample_result <- get_unique_values(
    suppressMessages(as_forecast_sample(example_sample_continuous))
  )
  expect_s3_class(sample_result, c("data.table", "data.frame"), exact = TRUE)

  # `sample_id` is a protected column and must not be reported either
  sample_reported <- sample_result[[names(sample_result)[1]]]
  expect_false("sample_id" %in% sample_reported)
})

test_that("get_unique_values() returns correct output structure", {
  # Minimal point forecast: two locations, two models
  toy_data <- data.table::data.table(
    location = c("A", "A", "B"),
    model = c("m1", "m2", "m1"),
    observed = c(1, 2, 3),
    predicted = c(1.1, 2.1, 3.1)
  )
  result <- get_unique_values(suppressMessages(as_forecast_point(toy_data)))

  expect_s3_class(result, c("data.table", "data.frame"), exact = TRUE)

  name_col <- names(result)[1]
  n_col <- names(result)[2]
  reported <- result[[name_col]]

  # Exactly one row per forecast unit column: location and model
  expect_identical(nrow(result), 2L)
  expect_setequal(reported, c("location", "model"))

  # Protected columns must not be listed
  for (protected in c("observed", "predicted")) {
    expect_false(protected %in% reported)
  }

  # Both forecast unit columns hold two distinct values
  for (unit_col in c("location", "model")) {
    expect_identical(result[result[[name_col]] == unit_col, ][[n_col]], 2L)
  }
})

test_that("get_unique_values() accepts a `by` argument for grouping", {
  forecast <- suppressMessages(as_forecast_quantile(example_quantile))
  result <- get_unique_values(forecast, by = "model")

  expect_s3_class(result, c("data.table", "data.frame"), exact = TRUE)

  # The grouping column comes first, followed by the two documented columns
  expect_named(result, c("model", "column", "N_unique"))

  # "model" is the grouping variable, so it must not be counted itself;
  # every other forecast unit column must be reported
  fu <- setdiff(get_forecast_unit(forecast), "model")
  expect_false("model" %in% result[["column"]])
  expect_setequal(result[["column"]], fu)

  # One row for each model x column combination (models counted after
  # rows with missing values have been removed)
  n_models <- length(unique(
    scoringutils:::clean_forecast(forecast, copy = TRUE, na.omit = TRUE)$model
  ))
  expect_identical(nrow(result), n_models * length(fu))
})

test_that("get_unique_values() errors on non-forecast input", {
  # Neither a character string nor a number is a forecast object
  for (bad_input in list("not a forecast", 42)) {
    expect_error(get_unique_values(bad_input))
  }
})

test_that("get_unique_values() handles data with NAs correctly", {
  # Blank out `location` in the first five rows of a copy of the example data
  with_missing <- data.table::copy(example_quantile)
  with_missing[seq_len(5L), location := NA]
  forecast <- suppressMessages(as_forecast_quantile(with_missing))

  # Counting must run cleanly and still return a plain data.table
  expect_no_error(counts <- get_unique_values(forecast))
  expect_s3_class(counts, c("data.table", "data.frame"), exact = TRUE)
})
Loading