knitr::opts_chunk$set(echo = TRUE)
suppressPackageStartupMessages({
library(data.table)
library(dplyr)
library(this.path)
library(stringplus)
library(qs2)
library(ggplot2)
})
# Start a simple wall-clock timer: stash the current time in the global
# environment under `.time` (read back later by `later()`).
now <- function() {
  .GlobalEnv[[".time"]] <- Sys.time()
}
# Elapsed wall-clock seconds since the most recent `now()` call
# (reads the `.time` timestamp stashed in the global environment).
later <- function() {
  started <- get(".time", envir = globalenv())
  as.numeric(difftime(Sys.time(), started, units = "secs"))
}
# Datasets for benchmarking
DATA_PATH <- "~/datasets/processed"
read_timer_script <- "blocksize_benchmark_read_timer_function.R"
SAVE_PATH <- tempfile()
RESULTS_PATH <- tempfile()
PLATFORM <- "ubuntu" # for caching and results name purposes

# training datasets: file names joined onto DATA_PATH with base file.path()
dataset_files <- c(
  "DC_real_estate_June_2024.json.gz",
  "dslabs_mnist.rds",
  "enwik8.csv.gz",
  "era5_land_wind_20230101.rds",
  "GAIA_pseudocolor.csv.gz",
  "NYSE_1962_2024.csv.gz",
  "recount3_gtex_heart.rds",
  "T_cell_ADIRP0000010.rds"
)
datasets <- file.path(DATA_PATH, dataset_files)

# Read a benchmark dataset, dispatching on the file extension:
#   *.json.gz -> RcppSimdJson::fload()
#   *.csv.gz  -> data.table::fread(), returned as a plain data.frame
#   otherwise -> readRDS()
#
# @param d Path to the dataset file.
# @return The parsed dataset object.
read_dataset <- function(d) {
  # Anchored, escaped extension patterns: the previous "json.gz" / ".csv.gz"
  # regexes left the dots unescaped, so they matched ANY character in that
  # position and could fire anywhere in the path.
  if (grepl("\\.json\\.gz$", d)) {
    RcppSimdJson::fload(d)
  } else if (grepl("\\.csv\\.gz$", d)) {
    fread(d, sep = ",", data.table = FALSE)
  } else {
    readRDS(d)
  }
}
# Cached benchmark driver: skip the (slow) grid run if results already exist
# on disk for this platform. `&` is stringplus string concatenation.
outfile <- PLATFORM & "_blocksize_optimization_data.csv.gz"
if(!file.exists(outfile)) {

  # Block sizes 2^17 .. 2^24 bytes (128 KiB .. 16 MiB).
  BLOCKSIZES <- as.integer(2^seq(17, 24)) %>% sort
  # Full factorial design: 2 compress levels x 8 blocksizes x 1 algorithm
  # x 2 thread counts x 5 replicates = 160 runs per dataset.
  grid <- expand.grid(cl = c(3, 9),
                      blocksize = BLOCKSIZES,
                      algo = c("qdata"),
                      nt = c(1,4),
                      rep = 1:5, stringsAsFactors = FALSE)
  
  results <- lapply(datasets, function(d) {
    DATA <- read_dataset(d)
    # Shuffle run order (locally shadows the outer `grid`) to decorrelate
    # timing from systematic drift (thermal throttling, caching). No seed is
    # set, so the order differs between invocations.
    grid <- sample_frac(grid, 1)
    lapply(1:nrow(grid), function(i) {
      cl <- grid$cl[i]
      b <- grid$blocksize[i]
      algo <- grid$algo[i]
      nt <- grid$nt[i]
      rep <- grid$rep[i]
      
      # Non-exported qs2 hook; presumably overrides the serialization block
      # size used by qs_save/qd_save -- TODO confirm against qs2 internals.
      qs2:::internal_set_blocksize(b)
      # Time the save with the now()/later() global-env timer.
      now()
      # NOTE(review): the grid only contains algo == "qdata", so the "qs2"
      # branch is currently dead; kept so the grid can be extended.
      if(algo == "qs2") {
        qs_save(DATA, file = SAVE_PATH, compress_level = cl, nthreads = nt, shuffle = FALSE)
      } else {
        qd_save(DATA, file = SAVE_PATH, compress_level = cl, nthreads = nt, shuffle = FALSE)
      }
      save_time <- later()
      
      # read timer
      # Reads are timed in a fresh Rscript subprocess (cold library/OS state);
      # `|` is stringplus template substitution of the {placeholders}. The
      # child script writes the elapsed seconds to RESULTS_PATH.
      system("Rscript {path} {algo} {blocksize} {nthreads} {output} {results}" | 
                list(path=read_timer_script, algo=algo, blocksize=b, nthreads=nt, output=SAVE_PATH, results=RESULTS_PATH))
      read_time <- readLines(RESULTS_PATH) %>% as.numeric
      
      # File size in MiB (bytes / 2^20).
      file_size <- file.info(SAVE_PATH)[1,"size"] / 1048576
      # One result row: the grid parameters plus timings, size, and the
      # dataset name stripped of everything after the first dot.
      grid[i,] %>% mutate(save_time = save_time, read_time = read_time, file_size = file_size, 
                          dataset = basename(d) %>% gsub("\\..+$", "", .))
    }) %>% rbindlist
  }) %>% rbindlist
  
  # Cache the raw per-run results for future knits on this platform.
  fwrite(results, file = outfile, sep = ",")
} else {
  results <- fread(outfile, data.table=FALSE)
}
# Collapse the replicate runs: median save/read time per
# (cl, blocksize, nt, dataset) combination; file_size is constant within a
# combination, so max() just selects the common value. `.groups = "drop"`
# returns an ungrouped frame and silences dplyr's regrouping message.
results <- results %>%
  group_by(cl, blocksize, nt, dataset) %>%
  summarize(save_time = median(save_time), read_time = median(read_time),
            file_size = max(file_size), .groups = "drop")
# The save-time and read-time panels were copy-pasted ggplot chains that
# differed only in the y aesthetic; build both with one helper instead.
#
# @param data Summarized benchmark results (one row per cl/blocksize/nt/dataset).
# @param yvar Name of the timing column to plot ("save_time" or "read_time").
# @return A ggplot object: timing vs. file size on log-log axes, colored by
#   compress level, line type by thread count, points labeled with
#   log2(blocksize), faceted by compress level x dataset.
plot_blocksize_results <- function(data, yvar) {
  ggplot(data, aes(y = .data[[yvar]], x = file_size,
                   color = factor(cl), lty = factor(nt))) +
    geom_line() +
    # Plain geom_text labels; ggrepel::geom_text_repel was tried previously
    # but dropped to avoid the extra dependency.
    geom_text(aes(label = log2(blocksize) %>% signif(3) %>% as.character),
              size = 2, color = "black") +
    scale_x_log10() +
    scale_y_log10(n.breaks = 8) +
    facet_wrap(cl ~ dataset, scales = "free", ncol = 4) +
    theme_bw(base_size = 11) +
    theme(legend.position = "bottom") +
    labs(color = "compress level", lty = "nthreads")
}

plot_blocksize_results(results, "save_time")

plot_blocksize_results(results, "read_time")