model_session <- R6::R6Class(
  lock_objects = FALSE,
  public = list(
    initialize = function() {
      self$task_q <- NULL
      self$temperature <- 1
      self$top_k <- 50
      self$is_loaded <- NULL
    },
    load_model = function(repo) {
      # An existing task queue means a load has already been scheduled.
      if (!is.null(self$task_q)) {
        cat("Model is already loaded.\n")
        return(self$task_q$push(function() "done"))
      }
      # Run the model in a single background worker so inference never
      # blocks the main (Shiny) process.
      self$task_q <- callq::task_q$new(num_workers = 1)
      self$task_q$push(args = list(repo = repo), function(repo) {
        library(torch)
        library(zeallot)
        library(minhub)
        # `<<-` stores the model and tokenizer in the worker's global
        # environment so subsequent tasks can reuse them.
        model <<- minhub::gptneox_from_pretrained(repo)
        model$eval()
        model$to(dtype = torch_float())
        tok <<- tok::tokenizer$from_pretrained(repo)
        "done"
      })
    },
    generate = function(prompt) {
      if (is.null(self$task_q)) {
        stop("Model is not loaded.")
      }
      args <- list(
        prompt = prompt,
        temperature = self$temperature,
        top_k = self$top_k
      )
      self$task_q$push(args = args, function(prompt, temperature, top_k) {
        # Encode the prompt; ids are shifted by 1 because the tokenizer is
        # 0-based while R torch tensors are 1-based.
        idx <- torch_tensor(tok$encode(prompt)$ids)$view(c(1, -1))
        with_no_grad({
          logits <- model(idx + 1L)
        })
        # Keep the logits for the last position and apply the temperature.
        logits <- logits[, -1, ] / temperature
        # Top-k filtering: everything outside the k largest logits gets -Inf.
        c(prob, ind) %<-% logits$topk(top_k)
        logits <- torch_full_like(logits, -Inf)$scatter_(-1, ind, prob)
        probs <- nnf_softmax(logits, dim = -1)
        # Sample one token and shift back to the tokenizer's 0-based ids.
        id_next <- torch::torch_multinomial(probs, num_samples = 1) - 1L
        tok$decode(as.integer(id_next))
      })
    }
  )
)
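
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the file above. It assumes that callq's
# task_q$push() returns a promise consumable with the promises package, and
# the repo name "EleutherAI/pythia-70m" is only an illustrative GPT-NeoX-style
# checkpoint; in the deployed app these calls would sit inside Shiny reactives.

library(promises)

sess <- model_session$new()

# Schedule the (slow) download and load on the background worker; the
# promise resolves to "done" once weights and tokenizer are ready.
then(
  sess$load_model("EleutherAI/pythia-70m"),
  onFulfilled = function(status) sess$is_loaded <- TRUE
)

# Each generate() call samples a single next token; a caller loops,
# appending each token to the prompt, until enough text is produced.
then(
  sess$generate("R is a language for"),
  onFulfilled = function(token) cat(token)
)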