dfalbel commited on
Commit
bf43b6e
1 Parent(s): 25afb76

Simplify to use task queue

Browse files
Files changed (4) hide show
  1. Dockerfile +3 -7
  2. app.R +1 -8
  3. model-session.R +7 -15
  4. promise-session.R +0 -70
Dockerfile CHANGED
@@ -22,13 +22,9 @@ RUN installGithub.r \
22
  mlverse/tok
23
 
24
  RUN installGithub.r \
25
- mlverse/minhub
26
-
27
- RUN installGithub.r \
28
- mlverse/hfhub
29
-
30
- RUN installGithub.r \
31
- mlverse/minhub
32
 
33
  # see: https://huggingface.co/docs/hub/spaces-sdks-docker#permissions
34
  RUN useradd -m -u 1000 user
 
22
  mlverse/tok
23
 
24
  RUN installGithub.r \
25
+ mlverse/minhub \
26
+ mlverse/hfhub \
27
+ mlverse/callq
 
 
 
 
28
 
29
  # see: https://huggingface.co/docs/hub/spaces-sdks-docker#permissions
30
  RUN useradd -m -u 1000 user
app.R CHANGED
@@ -4,16 +4,10 @@ library(minhub)
4
  library(magrittr)
5
  source("model-session.R")
6
 
7
- repo <- "stabilityai/stablelm-tuned-alpha-3b"
8
  repo <- Sys.getenv("MODEL_REPO", unset = repo)
9
  sess <- model_session$new()
10
 
11
- poll_process <- function() {
12
- sess$poll_process(1)
13
- later::later(func = poll_process, delay = 0.5)
14
- }
15
- poll_process()
16
-
17
  max_n_tokens <- 100
18
  system_prompt = "<|SYSTEM|># StableLM Tuned (Alpha version)
19
  - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
@@ -124,7 +118,6 @@ server <- function(input, output, session) {
124
 
125
  # this runs for the cases where sess$is_loaded was NULL
126
  # ie there was no model currently loading.
127
- cat("Loading model:",sess$sess$poll_process(), "\n")
128
  m <- model_loaded() %>%
129
  promises::then(onFulfilled = function(x) {
130
  cat("Model has been loaded!", "\n")
 
4
  library(magrittr)
5
  source("model-session.R")
6
 
7
+ repo <- "EleutherAI/pythia-70m"
8
  repo <- Sys.getenv("MODEL_REPO", unset = repo)
9
  sess <- model_session$new()
10
 
 
 
 
 
 
 
11
  max_n_tokens <- 100
12
  system_prompt = "<|SYSTEM|># StableLM Tuned (Alpha version)
13
  - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
 
118
 
119
  # this runs for the cases where sess$is_loaded was NULL
120
  # ie there was no model currently loading.
 
121
  m <- model_loaded() %>%
122
  promises::then(onFulfilled = function(x) {
123
  cat("Model has been loaded!", "\n")
model-session.R CHANGED
@@ -1,28 +1,20 @@
1
- source("promise-session.R")
2
 
3
- # A wrapper a around the promise session that controls model loading and
4
- # querying given a prompt
5
  model_session <- R6::R6Class(
6
  lock_objects = FALSE,
7
  public = list(
8
  initialize = function() {
9
- self$sess <- NULL
10
  self$temperature <- 1
11
  self$top_k <- 50
12
  self$is_loaded <- NULL
13
  },
14
- poll_process = function(timeout = 1) {
15
- if (!is.null(self$sess)) {
16
- self$sess$poll_process(timeout)
17
- }
18
- },
19
  load_model = function(repo) {
20
  if (!is.null(self$sess)) {
21
  cat("Model is already loaded.", "\n")
22
- return(self$sess$call(function() "done"))
23
  }
24
- self$sess <- promise_session$new()
25
- self$sess$call(args = list(repo = repo), function(repo) {
26
  library(torch)
27
  library(zeallot)
28
  library(minhub)
@@ -34,16 +26,16 @@ model_session <- R6::R6Class(
34
  })
35
  },
36
  generate = function(prompt) {
37
- if (is.null(self$sess)) {
38
  cat("Model is not loaded, error.", "\n")
39
- return(self$sess$call(function() stop("Model is not loaded")))
40
  }
41
  args <- list(
42
  prompt = prompt,
43
  temperature = self$temperature,
44
  top_k = self$top_k
45
  )
46
- self$sess$call(args = args, function(prompt, temperature, top_k) {
47
  idx <- torch_tensor(tok$encode(prompt)$ids)$view(c(1, -1))
48
  with_no_grad({
49
  logits <- model(idx + 1L)
 
 
1
 
 
 
2
  model_session <- R6::R6Class(
3
  lock_objects = FALSE,
4
  public = list(
5
  initialize = function() {
6
+ self$task_q <- NULL
7
  self$temperature <- 1
8
  self$top_k <- 50
9
  self$is_loaded <- NULL
10
  },
 
 
 
 
 
11
  load_model = function(repo) {
12
  if (!is.null(self$sess)) {
13
  cat("Model is already loaded.", "\n")
14
+ return(self$task_q$push(function() "done"))
15
  }
16
+ self$task_q <- callq::task_q$new(num_workers = 1)
17
+ self$task_q$push(args = list(repo = repo), function(repo) {
18
  library(torch)
19
  library(zeallot)
20
  library(minhub)
 
26
  })
27
  },
28
  generate = function(prompt) {
29
+ if (is.null(self$task_q)) {
30
  cat("Model is not loaded, error.", "\n")
31
+ return(self$task_q$push(function() stop("Model is not loaded")))
32
  }
33
  args <- list(
34
  prompt = prompt,
35
  temperature = self$temperature,
36
  top_k = self$top_k
37
  )
38
+ self$task_q$push(args = args, function(prompt, temperature, top_k) {
39
  idx <- torch_tensor(tok$encode(prompt)$ids)$view(c(1, -1))
40
  with_no_grad({
41
  logits <- model(idx + 1L)
promise-session.R DELETED
@@ -1,70 +0,0 @@
1
-
2
- # Small utility class that wraps a `callr::r_session` to return promises when
3
- # executing `sess$call()`.
4
- # Only one promise is resolve per time in fifo way.
5
- promise_session <- R6::R6Class(
6
- lock_objects = FALSE,
7
- public = list(
8
- initialize = function() {
9
- self$sess <- callr::r_session$new()
10
- self$is_running <- FALSE
11
- },
12
- call = function(func, args = list()) {
13
- self$poll_process()
14
- promises::promise(function(resolve, reject) {
15
- self$push_task(func, args, resolve, reject)
16
- later::later(self$poll_process, 1)
17
- })
18
- },
19
- push_task = function(func, args, resolve, reject) {
20
- self$tasks[[length(self$tasks) + 1]] <- list(
21
- func = func,
22
- args = args,
23
- resolve = resolve,
24
- reject = reject
25
- )
26
- cat("task pushed, now we have ", length(self$tasks), " on queue\n")
27
- self$run_task()
28
- invisible(NULL)
29
- },
30
- run_task = function() {
31
- if (self$is_running) return(NULL)
32
- if (length(self$tasks) == 0) return(NULL)
33
-
34
- self$is_running <- TRUE
35
- task <- self$tasks[[1]]
36
- self$sess$call(task$func, args = task$args)
37
- },
38
- resolve_task = function() {
39
- cat("Resolving task! ")
40
- out <- self$sess$read()
41
- if (!is.null(out$error)) {
42
- self$tasks[[1]]$reject(out$error)
43
- } else {
44
- self$tasks[[1]]$resolve(out$result)
45
- }
46
-
47
- self$tasks <- self$tasks[-1]
48
- cat("now we have ", length(self$tasks), "on queue\n")
49
-
50
- self$is_running <- FALSE
51
-
52
- self$run_task()
53
- },
54
- poll_process = function(timeout = 1) {
55
- if (!self$is_running) return("ready")
56
- poll_state <- self$sess$poll_process(timeout)
57
- if (poll_state == "ready") {
58
- self$resolve_task()
59
- }
60
- poll_state
61
- }
62
- )
63
- )
64
-
65
- # sess <- promise_session$new()
66
- # f <- sess$call(function(a) {
67
- # 10 + 1
68
- # }, list(1))
69
- # sess$poll_process()
70
-