File size: 1,659 Bytes
8a3fba7
 
 
 
 
bf43b6e
8a3fba7
 
c29126b
8a3fba7
 
cfa2bcb
 
bf43b6e
cfa2bcb
bf43b6e
 
8a3fba7
 
 
 
 
 
 
 
 
 
 
bf43b6e
35d3c36
bf43b6e
35d3c36
8a3fba7
 
 
 
 
bf43b6e
8a3fba7
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

# model_session: runs a GPT-NeoX model in a background worker and samples
# one next token at a time.
#
# All torch work (loading and forward passes) is pushed onto a single-worker
# `callq::task_q`, so the calling R process is not the one running the model.
# `lock_objects = FALSE` allows fields to be added or reassigned on the
# instance after construction.
#
# Fields:
#   task_q      - the worker queue; NULL until `load_model()` is first called.
#   temperature - divisor applied to the logits before sampling (default 1).
#   top_k       - number of highest-scoring tokens kept for sampling
#                 (default 50; clamped to the vocabulary size at sample time).
#   is_loaded   - initialised to NULL; never updated inside this class.
model_session <- R6::R6Class(
  lock_objects = FALSE,
  public = list(
    initialize = function() {
      self$task_q <- NULL
      self$temperature <- 1
      self$top_k <- 50
      self$is_loaded <- NULL
    },

    # Start the worker and load the model and tokenizer for `repo` into it.
    #
    # repo: identifier passed to `minhub::gptneox_from_pretrained()` and
    #       `tok::tokenizer$from_pretrained()`.
    # Returns whatever `task_q$push()` returns for the load task; the task
    # itself evaluates to "done". If a worker already exists, no new model is
    # loaded (even for a different `repo`); a trivial "done" task is pushed
    # instead.
    load_model = function(repo) {
      # `task_q` is created only here, so its presence is the loaded marker.
      # This prevents spawning (and leaking) a second worker on repeat calls.
      if (!is.null(self$task_q)) {
        cat("Model is already loaded.", "\n")
        return(self$task_q$push(function() "done"))
      }
      self$task_q <- callq::task_q$new(num_workers = 1)
      self$task_q$push(args = list(repo = repo), function(repo) {
        library(torch)
        library(zeallot)
        library(minhub)
        # `<<-` stores model/tokenizer in the worker's global environment so
        # later tasks pushed to the same worker (see `generate`) can use them.
        model <<- minhub::gptneox_from_pretrained(repo)
        model$eval()
        model$to(dtype = torch_float())
        tok <<- tok::tokenizer$from_pretrained(repo)
        "done"
      })
    },

    # Sample a single next token for `prompt` in the worker.
    #
    # prompt: text passed to the tokenizer's `encode()`.
    # Returns whatever `task_q$push()` returns; the task evaluates to the
    # decoded text of one sampled token.
    # Errors immediately with "Model is not loaded" if `load_model()` has not
    # been called, since there is no queue to run the task on.
    generate = function(prompt) {
      if (is.null(self$task_q)) {
        cat("Model is not loaded, error.", "\n")
        stop("Model is not loaded", call. = FALSE)
      }
      # Snapshot the sampling settings now; the worker cannot see `self`.
      args <- list(
        prompt = prompt,
        temperature = self$temperature,
        top_k = self$top_k
      )
      self$task_q$push(args = args, function(prompt, temperature, top_k) {
        # Shape the token ids as a single-row batch.
        idx <- torch_tensor(tok$encode(prompt)$ids)$view(c(1, -1))
        # The +1L here and the -1L after sampling shift between tokenizer ids
        # and torch's 1-based indices (presumably 0-based tokenizer ids).
        with_no_grad({
          logits <- model(idx + 1L)
        })
        # Torch for R treats -1 as the last index: keep the final position.
        # Assumes model output is (batch, seq, vocab) -- TODO confirm.
        logits <- logits[, -1, ] / temperature
        # `topk` errors when k exceeds the last dimension, so clamp it.
        k <- min(top_k, logits$shape[length(logits$shape)])
        c(prob, ind) %<-% logits$topk(k)
        # Every token outside the top k gets -Inf, i.e. zero mass after softmax.
        logits <- torch_full_like(logits, -Inf)$scatter_(-1, ind, prob)
        logits <- nnf_softmax(logits, dim = -1)
        id_next <- torch::torch_multinomial(logits, num_samples = 1) - 1L
        tok$decode(as.integer(id_next))
      })
    }
  )
)