#' Session object that runs a GPT-NeoX model on a background worker queue.
#'
#' The tokenizer is kept in the calling process (`self$tok`), while the model
#' itself is loaded inside a single `callq` worker and reused by later tasks
#' pushed to the same queue.
#'
#' Fields (all created dynamically, hence `lock_objects = FALSE`):
#' - `task_q`: the `callq::task_q` worker queue, or `NULL` until
#'   `load_model()` has been called.
#' - `temperature`: divisor applied to the logits before sampling (default 1).
#' - `top_k`: number of highest-scoring tokens kept for sampling (default 50).
#' - `is_loaded`: initialised to `NULL`; not read by any method in this class.
#' - `tok`: tokenizer created by `load_model()`.
model_session <- R6::R6Class(
  lock_objects = FALSE,
  public = list(
    #' @description Create an empty session; no worker is started and no
    #'   model is loaded until `load_model()` is called.
    initialize = function() {
      self$task_q <- NULL
      self$temperature <- 1
      self$top_k <- 50
      self$is_loaded <- NULL
    },

    #' @description Start the worker queue and load the model inside it.
    #' @param repo Repository identifier passed to both
    #'   `tok::tokenizer$from_pretrained()` and
    #'   `minhub::gptneox_from_pretrained()`.
    #' @return The value of `task_q$push()` for a task that evaluates to
    #'   `"done"` (presumably a promise-like handle; confirm against callq).
    load_model = function(repo) {
      # `task_q` is only non-NULL after a previous load_model() call, so it is
      # the marker for "already loaded". Checking it (instead of a field that
      # is never assigned) avoids spawning a second worker and reloading.
      if (!is.null(self$task_q)) {
        cat("Model is already loaded.", "\n")
        return(self$task_q$push(function() "done"))
      }

      # the tokenizer doesn't need to live in the remote session.
      self$tok <- tok::tokenizer$from_pretrained(repo)
      self$task_q <- callq::task_q$new(num_workers = 1)

      self$task_q$push(args = list(repo = repo), function(repo) {
        library(torch)
        library(zeallot)
        library(minhub)

        device <- if (cuda_is_available()) "cuda" else "cpu"

        # `<<-` is deliberate: the model must outlive this task so that the
        # tasks pushed by generate() can find `model` on the same worker.
        model <<- minhub::gptneox_from_pretrained(repo)
        model$eval()

        if (device == "cuda") {
          model$to(dtype = torch_half())
          model$to(device = device)
        } else {
          model$to(dtype = torch_float())
        }

        "done"
      })
    },

    #' @description Sample the next token id for a sequence of token ids.
    #' @param idx Token ids of the context; reshaped to a single row before
    #'   being fed to the model.
    #' @return The value of `task_q$push()` for a task that evaluates to the
    #'   sampled token id as an integer.
    generate = function(idx) {
      if (is.null(self$task_q)) {
        cat("Model is not loaded, error.", "\n")
        # There is no queue to push a failing task onto, so signal the error
        # directly instead of calling `$push` on NULL.
        stop("Model is not loaded", call. = FALSE)
      }

      args <- list(
        idx = idx,
        temperature = self$temperature,
        top_k = self$top_k
      )

      self$task_q$push(args = args, function(idx, temperature, top_k) {
        device <- if (cuda_is_available()) "cuda" else "cpu"
        idx <- torch_tensor(idx, device = device)$view(c(1, -1))

        with_no_grad({
          # `+ 1L` here and `- 1L` below shift between the caller's ids and the
          # ids the model sees (presumably 0-based tokenizer ids vs. 1-based
          # torch indexing; confirm against tok/minhub).
          logits <- model(idx + 1L)$to(dtype = "float", device = "cpu")
        })

        # Keep only the last position of the sequence and apply temperature.
        logits <- logits[, -1, ] / temperature

        # Top-k filtering: everything outside the k best logits is pushed to a
        # very negative value so that softmax gives it ~0 probability.
        c(prob, ind) %<-% logits$topk(top_k)
        logits <- torch_full_like(logits, -1e7)$scatter_(-1, ind, prob)
        logits <- nnf_softmax(logits, dim = -1)

        id_next <- torch::torch_multinomial(logits, num_samples = 1)$cpu() - 1L
        as.integer(id_next)
      })
    }
  )
)