# Spaces:
# Runtime error
# Runtime error
#' Background text-generation session for a GPT-NeoX model.
#'
#' Loads a minhub GPT-NeoX model and its tokenizer inside a single-worker
#' `callq` task queue, then samples one next token per `generate()` call
#' using temperature scaling and top-k filtering. Model loading and the
#' forward pass run in the queue's worker, not in the calling R process.
#'
#' Public fields:
#' * `task_q`      - the `callq::task_q` once `load_model()` has run,
#'                   `NULL` before that.
#' * `temperature` - default sampling temperature (divides the logits).
#' * `top_k`       - default number of highest-scoring tokens kept before
#'                   sampling.
#' * `is_loaded`   - initialised to `NULL`; not written by this class.
model_session <- R6::R6Class(
  lock_objects = FALSE,
  public = list(
    initialize = function() {
      self$task_q <- NULL
      self$temperature <- 1
      self$top_k <- 50
      self$is_loaded <- NULL
    },

    # Start the worker and load `repo` (model + tokenizer) into it.
    #
    # @param repo Repository identifier passed to
    #   `minhub::gptneox_from_pretrained()` and
    #   `tok::tokenizer$from_pretrained()`.
    # @return Whatever `task_q$push()` returns for the loading task
    #   (presumably a promise resolving to "done" - confirm against callq).
    #   If a queue already exists, no new worker is started; a trivial task
    #   returning "done" is pushed instead.
    load_model = function(repo) {
      # The queue is the only state this method creates, so its presence is
      # what marks the session as loaded. This also prevents spawning (and
      # orphaning) a second worker on repeated calls.
      if (!is.null(self$task_q)) {
        cat("Model is already loaded.", "\n")
        return(self$task_q$push(function() "done"))
      }
      self$task_q <- callq::task_q$new(num_workers = 1)
      self$task_q$push(args = list(repo = repo), function(repo) {
        library(torch)
        library(zeallot)
        library(minhub)
        device <- if (cuda_is_available()) "cuda" else "cpu"
        # `<<-` stores the model and tokenizer in the worker's global
        # environment so later tasks on the same single worker (see
        # `generate()`) can reach them.
        model <<- minhub::gptneox_from_pretrained(repo)
        model$eval()
        if (device == "cuda") {
          # On CUDA the model keeps the dtype it was loaded with.
          model$to(device = device)
        } else {
          # Only the CPU path casts the weights to float32.
          model$to(dtype = torch_float())
        }
        tok <<- tok::tokenizer$from_pretrained(repo)
        "done"
      })
    },

    # Sample the next token that continues `prompt`, in the worker.
    #
    # @param prompt Text to continue.
    # @param temperature Number dividing the logits before sampling;
    #   defaults to the session's `temperature` field.
    # @param top_k Number of highest-scoring tokens kept before sampling;
    #   defaults to the session's `top_k` field.
    # @return Whatever `task_q$push()` returns for the sampling task; the
    #   task's result is the decoded text of one sampled token.
    # @section Errors: signals "Model is not loaded" if `load_model()` has
    #   not been called.
    generate = function(prompt, temperature = self$temperature,
                        top_k = self$top_k) {
      if (is.null(self$task_q)) {
        cat("Model is not loaded, error.", "\n")
        # There is no queue to push an erroring task onto, so fail here.
        stop("Model is not loaded", call. = FALSE)
      }
      args <- list(
        prompt = prompt,
        temperature = temperature,
        top_k = top_k
      )
      self$task_q$push(args = args, function(prompt, temperature, top_k) {
        # Runs in the worker: relies on the packages attached and the
        # global `model` / `tok` created by the load_model() task.
        device <- if (cuda_is_available()) "cuda" else "cpu"
        ids <- tok$encode(prompt)$ids
        # view(c(1, -1)) reshapes the ids into a single-row (1 x n) batch.
        idx <- torch_tensor(ids, device = device)$view(c(1, -1))
        with_no_grad({
          # +1L shifts the tokenizer's 0-based ids to torch-for-R's 1-based
          # indexing (presumably what the embedding lookup expects).
          logits <- model(idx + 1L)$to(dtype = "float", device = "cpu")
        })
        # In torch for R a negative index counts from the end, so this keeps
        # only the logits of the last position.
        logits <- logits[, -1, ] / temperature
        # Top-k filtering: keep the k largest logits and fill every other
        # slot with -1e7 so it gets ~zero probability after softmax.
        c(top_logits, top_idx) %<-% logits$topk(top_k)
        logits <- torch_full_like(logits, -1e7)$scatter_(-1, top_idx, top_logits)
        probs <- nnf_softmax(logits, dim = -1)
        # -1L maps the sampled index back to a 0-based tokenizer id.
        id_next <- torch::torch_multinomial(probs, num_samples = 1)$cpu() - 1L
        tok$decode(as.integer(id_next))
      })
    }
  )
)