# R6 class that owns a single background worker (a callq task queue) in which
# a GPT-NeoX model (via minhub) and its tokenizer (via tok) are loaded, and
# which samples one next token per `generate()` call.
model_session <- R6::R6Class(
  lock_objects = FALSE,
  public = list(
    # Create an empty session. No worker is started until `load_model()`.
    initialize = function() {
      # Task queue with one worker; NULL means "not loaded yet".
      self$task_q <- NULL
      # Sampling parameters, read by `generate()` at call time.
      self$temperature <- 1
      self$top_k <- 50
      # NOTE(review): never updated by this class; kept only because
      # external code may read it.
      self$is_loaded <- NULL
    },

    # Start the worker and load model + tokenizer from `repo` inside it.
    #
    # repo: identifier passed to `minhub::gptneox_from_pretrained()` and
    #   `tok::tokenizer$from_pretrained()`.
    # Returns whatever `task_q$push()` returns for the load task (the task's
    #   value is "done"). If a queue already exists, no new worker is started
    #   and a trivial task that returns "done" is pushed instead.
    load_model = function(repo) {
      # Check the queue itself: it is the state `generate()` relies on.
      if (!is.null(self$task_q)) {
        cat("Model is already loaded.", "\n")
        return(self$task_q$push(function() "done"))
      }
      # NOTE(review): if the load task below fails, `task_q` stays non-NULL,
      # so a retry will report "already loaded" — confirm this is acceptable.
      self$task_q <- callq::task_q$new(num_workers = 1)
      self$task_q$push(args = list(repo = repo), function(repo) {
        # Runs in the worker. The `generate()` task uses `model`, `tok` and
        # unqualified torch/zeallot functions without loading them, so it
        # depends on this task having run first on the same single worker.
        library(torch)
        library(zeallot)
        library(minhub)
        model <<- minhub::gptneox_from_pretrained(repo)
        model$eval()
        model$to(dtype = torch_float())
        tok <<- tok::tokenizer$from_pretrained(repo)
        "done"
      })
    },

    # Sample a single next token for `prompt` in the worker, using the
    # session's current `temperature` and `top_k`.
    #
    # prompt: text passed to `tok$encode()`.
    # Returns whatever `task_q$push()` returns; the task's value is the
    #   decoded next-token string.
    # Errors immediately (condition message "Model is not loaded") if
    #   `load_model()` has not been called.
    generate = function(prompt) {
      if (is.null(self$task_q)) {
        # There is no queue to push a failing task onto, so fail here.
        stop("Model is not loaded", call. = FALSE)
      }
      # Snapshot the sampling parameters as they are at call time.
      args <- list(
        prompt = prompt,
        temperature = self$temperature,
        top_k = self$top_k
      )
      self$task_q$push(args = args, function(prompt, temperature, top_k) {
        # Token ids reshaped to a (1, n_tokens) tensor.
        idx <- torch_tensor(tok$encode(prompt)$ids)$view(c(1, -1))
        with_no_grad({
          # +1L presumably shifts the tokenizer's 0-based ids to R torch's
          # 1-based indexing — confirm against minhub.
          logits <- model(idx + 1L)
        })
        # In torch for R, index -1 selects the last element (unlike base R):
        # keep the final position's logits and apply temperature.
        logits <- logits[, -1, ] / temperature
        # Top-k filtering: everything outside the k largest logits -> -Inf.
        c(top_logits, top_idx) %<-% logits$topk(top_k)
        logits <- torch_full_like(logits, -Inf)$scatter_(-1, top_idx, top_logits)
        probs <- nnf_softmax(logits, dim = -1)
        # -1L mirrors the +1L above: back to the tokenizer's id space.
        id_next <- torch::torch_multinomial(probs, num_samples = 1) - 1L
        tok$decode(as.integer(id_next))
      })
    }
  )
)