Spaces:
Runtime error
Runtime error
Simplify to use task queue
Browse files- Dockerfile +3 -7
- app.R +1 -8
- model-session.R +7 -15
- promise-session.R +0 -70
Dockerfile
CHANGED
@@ -22,13 +22,9 @@ RUN installGithub.r \
|
|
22 |
mlverse/tok
|
23 |
|
24 |
RUN installGithub.r \
|
25 |
-
mlverse/minhub
|
26 |
-
|
27 |
-
|
28 |
-
mlverse/hfhub
|
29 |
-
|
30 |
-
RUN installGithub.r \
|
31 |
-
mlverse/minhub
|
32 |
|
33 |
# see: https://huggingface.co/docs/hub/spaces-sdks-docker#permissions
|
34 |
RUN useradd -m -u 1000 user
|
|
|
22 |
mlverse/tok
|
23 |
|
24 |
RUN installGithub.r \
|
25 |
+
mlverse/minhub \
|
26 |
+
mlverse/hfhub \
|
27 |
+
mlverse/callq
|
|
|
|
|
|
|
|
|
28 |
|
29 |
# see: https://huggingface.co/docs/hub/spaces-sdks-docker#permissions
|
30 |
RUN useradd -m -u 1000 user
|
app.R
CHANGED
@@ -4,16 +4,10 @@ library(minhub)
|
|
4 |
library(magrittr)
|
5 |
source("model-session.R")
|
6 |
|
7 |
-
repo <- "
|
8 |
repo <- Sys.getenv("MODEL_REPO", unset = repo)
|
9 |
sess <- model_session$new()
|
10 |
|
11 |
-
poll_process <- function() {
|
12 |
-
sess$poll_process(1)
|
13 |
-
later::later(func = poll_process, delay = 0.5)
|
14 |
-
}
|
15 |
-
poll_process()
|
16 |
-
|
17 |
max_n_tokens <- 100
|
18 |
system_prompt = "<|SYSTEM|># StableLM Tuned (Alpha version)
|
19 |
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
@@ -124,7 +118,6 @@ server <- function(input, output, session) {
|
|
124 |
|
125 |
# this runs for the cases where sess$is_loaded was NULL
|
126 |
# ie there was no model currently loading.
|
127 |
-
cat("Loading model:",sess$sess$poll_process(), "\n")
|
128 |
m <- model_loaded() %>%
|
129 |
promises::then(onFulfilled = function(x) {
|
130 |
cat("Model has been loaded!", "\n")
|
|
|
4 |
library(magrittr)
|
5 |
source("model-session.R")
|
6 |
|
7 |
+
repo <- "EleutherAI/pythia-70m"
|
8 |
repo <- Sys.getenv("MODEL_REPO", unset = repo)
|
9 |
sess <- model_session$new()
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
max_n_tokens <- 100
|
12 |
system_prompt = "<|SYSTEM|># StableLM Tuned (Alpha version)
|
13 |
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
|
|
118 |
|
119 |
# this runs for the cases where sess$is_loaded was NULL
|
120 |
# ie there was no model currently loading.
|
|
|
121 |
m <- model_loaded() %>%
|
122 |
promises::then(onFulfilled = function(x) {
|
123 |
cat("Model has been loaded!", "\n")
|
model-session.R
CHANGED
@@ -1,28 +1,20 @@
|
|
1 |
-
source("promise-session.R")
|
2 |
|
3 |
-
# A wrapper around the promise session that controls model loading and
|
4 |
-
# querying given a prompt
|
5 |
model_session <- R6::R6Class(
|
6 |
lock_objects = FALSE,
|
7 |
public = list(
|
8 |
initialize = function() {
|
9 |
-
self$
|
10 |
self$temperature <- 1
|
11 |
self$top_k <- 50
|
12 |
self$is_loaded <- NULL
|
13 |
},
|
14 |
-
poll_process = function(timeout = 1) {
|
15 |
-
if (!is.null(self$sess)) {
|
16 |
-
self$sess$poll_process(timeout)
|
17 |
-
}
|
18 |
-
},
|
19 |
load_model = function(repo) {
|
20 |
if (!is.null(self$sess)) {
|
21 |
cat("Model is already loaded.", "\n")
|
22 |
-
return(self$
|
23 |
}
|
24 |
-
self$
|
25 |
-
self$
|
26 |
library(torch)
|
27 |
library(zeallot)
|
28 |
library(minhub)
|
@@ -34,16 +26,16 @@ model_session <- R6::R6Class(
|
|
34 |
})
|
35 |
},
|
36 |
generate = function(prompt) {
|
37 |
-
if (is.null(self$
|
38 |
cat("Model is not loaded, error.", "\n")
|
39 |
-
return(self$
|
40 |
}
|
41 |
args <- list(
|
42 |
prompt = prompt,
|
43 |
temperature = self$temperature,
|
44 |
top_k = self$top_k
|
45 |
)
|
46 |
-
self$
|
47 |
idx <- torch_tensor(tok$encode(prompt)$ids)$view(c(1, -1))
|
48 |
with_no_grad({
|
49 |
logits <- model(idx + 1L)
|
|
|
|
|
1 |
|
|
|
|
|
2 |
model_session <- R6::R6Class(
|
3 |
lock_objects = FALSE,
|
4 |
public = list(
|
5 |
initialize = function() {
|
6 |
+
self$task_q <- NULL
|
7 |
self$temperature <- 1
|
8 |
self$top_k <- 50
|
9 |
self$is_loaded <- NULL
|
10 |
},
|
|
|
|
|
|
|
|
|
|
|
11 |
load_model = function(repo) {
|
12 |
if (!is.null(self$sess)) {
|
13 |
cat("Model is already loaded.", "\n")
|
14 |
+
return(self$task_q$push(function() "done"))
|
15 |
}
|
16 |
+
self$task_q <- callq::task_q$new(num_workers = 1)
|
17 |
+
self$task_q$push(args = list(repo = repo), function(repo) {
|
18 |
library(torch)
|
19 |
library(zeallot)
|
20 |
library(minhub)
|
|
|
26 |
})
|
27 |
},
|
28 |
generate = function(prompt) {
|
29 |
+
if (is.null(self$task_q)) {
|
30 |
cat("Model is not loaded, error.", "\n")
|
31 |
+
return(self$task_q$push(function() stop("Model is not loaded")))
|
32 |
}
|
33 |
args <- list(
|
34 |
prompt = prompt,
|
35 |
temperature = self$temperature,
|
36 |
top_k = self$top_k
|
37 |
)
|
38 |
+
self$task_q$push(args = args, function(prompt, temperature, top_k) {
|
39 |
idx <- torch_tensor(tok$encode(prompt)$ids)$view(c(1, -1))
|
40 |
with_no_grad({
|
41 |
logits <- model(idx + 1L)
|
promise-session.R
DELETED
@@ -1,70 +0,0 @@
|
|
1 |
-
|
2 |
-
# Small utility class that wraps a `callr::r_session` to return promises when
|
3 |
-
# executing `sess$call()`.
|
4 |
-
# Only one promise is resolved at a time, in FIFO order.
|
5 |
-
promise_session <- R6::R6Class(
|
6 |
-
lock_objects = FALSE,
|
7 |
-
public = list(
|
8 |
-
initialize = function() {
|
9 |
-
self$sess <- callr::r_session$new()
|
10 |
-
self$is_running <- FALSE
|
11 |
-
},
|
12 |
-
call = function(func, args = list()) {
|
13 |
-
self$poll_process()
|
14 |
-
promises::promise(function(resolve, reject) {
|
15 |
-
self$push_task(func, args, resolve, reject)
|
16 |
-
later::later(self$poll_process, 1)
|
17 |
-
})
|
18 |
-
},
|
19 |
-
push_task = function(func, args, resolve, reject) {
|
20 |
-
self$tasks[[length(self$tasks) + 1]] <- list(
|
21 |
-
func = func,
|
22 |
-
args = args,
|
23 |
-
resolve = resolve,
|
24 |
-
reject = reject
|
25 |
-
)
|
26 |
-
cat("task pushed, now we have ", length(self$tasks), " on queue\n")
|
27 |
-
self$run_task()
|
28 |
-
invisible(NULL)
|
29 |
-
},
|
30 |
-
run_task = function() {
|
31 |
-
if (self$is_running) return(NULL)
|
32 |
-
if (length(self$tasks) == 0) return(NULL)
|
33 |
-
|
34 |
-
self$is_running <- TRUE
|
35 |
-
task <- self$tasks[[1]]
|
36 |
-
self$sess$call(task$func, args = task$args)
|
37 |
-
},
|
38 |
-
resolve_task = function() {
|
39 |
-
cat("Resolving task! ")
|
40 |
-
out <- self$sess$read()
|
41 |
-
if (!is.null(out$error)) {
|
42 |
-
self$tasks[[1]]$reject(out$error)
|
43 |
-
} else {
|
44 |
-
self$tasks[[1]]$resolve(out$result)
|
45 |
-
}
|
46 |
-
|
47 |
-
self$tasks <- self$tasks[-1]
|
48 |
-
cat("now we have ", length(self$tasks), "on queue\n")
|
49 |
-
|
50 |
-
self$is_running <- FALSE
|
51 |
-
|
52 |
-
self$run_task()
|
53 |
-
},
|
54 |
-
poll_process = function(timeout = 1) {
|
55 |
-
if (!self$is_running) return("ready")
|
56 |
-
poll_state <- self$sess$poll_process(timeout)
|
57 |
-
if (poll_state == "ready") {
|
58 |
-
self$resolve_task()
|
59 |
-
}
|
60 |
-
poll_state
|
61 |
-
}
|
62 |
-
)
|
63 |
-
)
|
64 |
-
|
65 |
-
# sess <- promise_session$new()
|
66 |
-
# f <- sess$call(function(a) {
|
67 |
-
# 10 + 1
|
68 |
-
# }, list(1))
|
69 |
-
# sess$poll_process()
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|