Spaces:
Runtime error
Runtime error
Make the app nicer!
Browse files- Dockerfile +2 -1
- app.R +109 -37
- gptneox.Rproj +13 -0
- model-session.R +45 -0
- promise-session.R +67 -0
Dockerfile
CHANGED
@@ -5,7 +5,8 @@ WORKDIR /code
|
|
5 |
# Install stable packages from CRAN
|
6 |
RUN install2.r --error \
|
7 |
ggExtra \
|
8 |
-
shiny
|
|
|
9 |
|
10 |
# Install Rust for tok
|
11 |
|
|
|
5 |
# Install stable packages from CRAN
|
6 |
RUN install2.r --error \
|
7 |
ggExtra \
|
8 |
+
shiny \
|
9 |
+
callr
|
10 |
|
11 |
# Install Rust for tok
|
12 |
|
app.R
CHANGED
@@ -1,53 +1,125 @@
|
|
1 |
library(shiny)
|
2 |
library(bslib)
|
3 |
-
library(dplyr)
|
4 |
-
library(ggplot2)
|
5 |
library(minhub)
|
|
|
6 |
|
7 |
-
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
ui <- page_fillable(
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
),
|
25 |
-
|
26 |
)
|
27 |
)
|
28 |
|
29 |
server <- function(input, output, session) {
|
30 |
-
subsetted <- reactive({
|
31 |
-
req(input$species)
|
32 |
-
df |> filter(Species %in% input$species)
|
33 |
-
})
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
)
|
42 |
-
|
43 |
-
if (input$show_margins) {
|
44 |
-
margin_type <- if (input$by_species) "density" else "histogram"
|
45 |
-
p <- p |> ggExtra::ggMarginal(type = margin_type, margins = "both",
|
46 |
-
size = 8, groupColour = input$by_species, groupFill = input$by_species)
|
47 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
}
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
shinyApp(ui, server)
|
|
|
1 |
library(shiny)
|
2 |
library(bslib)
|
|
|
|
|
3 |
library(minhub)
|
4 |
+
source("model-session.R")
|
5 |
|
6 |
+
# Hugging Face repository the model and tokenizer are loaded from.
repo <- "stabilityai/stablelm-tuned-alpha-3b"
# Background model session; loading is queued right away at app startup.
sess <- model_session$new()
# Holds the value returned by `load_model()` for the queued load.
model_loaded <- sess$load_model(repo)

# Upper bound on the number of tokens generated for a single reply.
max_n_tokens <- 100
# Text every conversation starts with, before the first user turn is appended.
system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"
|
17 |
|
18 |
+
# Page layout: a card that collects the chat messages, followed by a row
# holding the prompt box and the send button.
ui <- local({
  # Messages are inserted into the card body with id "messages".
  message_area <- card(
    height = "90%",
    heights_equal = "row",
    width = 1,
    fillable = FALSE,
    card_body(id = "messages", gap = 5, fillable = FALSE)
  )

  # The button starts out labelled "Loading model..."; the server relabels it
  # to "Send" once the model session reports ready.
  prompt_row <- layout_column_wrap(
    width = 1 / 2,
    textInput("prompt", label = NULL, width = "100%"),
    actionButton("send", "Loading model...", width = "100%")
  )

  page_fillable(
    theme = bs_theme(bootswatch = "minty"),
    shinyjs::useShinyjs(),
    message_area,
    prompt_row
  )
})
|
34 |
|
35 |
# Shiny server: drives the chat loop. Each click on "send" appends the user
# turn to the running prompt; every change of the prompt requests one more
# token from the model session, and each received token is appended to the
# prompt again until an empty token arrives or `max_n_tokens` is reached.
server <- function(input, output, session) {

  # Full conversation text fed to the model; starts as the system prompt and
  # grows with every user turn and every generated token.
  prompt <- reactiveVal(value = system_prompt)
  # Number of tokens requested so far for the reply currently being generated.
  n_tokens <- reactiveVal(value = 0)

  # Restores the send button and resets the token counter when a reply ends.
  finish_reply <- function() {
    shinyjs::enable("send")
    updateActionButton(inputId = "send", label = "Send")
    n_tokens(0)
  }

  observeEvent(input$send, {
    # Nothing to send for an empty prompt box.
    if (is.null(input$prompt) || input$prompt == "") {
      return()
    }
    shinyjs::disable("send")
    updateActionButton(inputId = "send", label = "Waiting for model")
    insert_message(as.character(glue::glue("🤗: {input$prompt}")))

    # we modify the prompt to trigger the 'next_token' reactive
    prompt(paste0(prompt(), "<|USER|>", input$prompt, "<|ASSISTANT|>"))
  })

  # Called directly rather than through `%>%`: no package that is attached at
  # the top of this file is known to provide the pipe operator.
  next_token <- eventReactive(prompt(), ignoreInit = TRUE, {
    sess$generate(prompt())
  })

  observeEvent(next_token(), {
    tok <- next_token()
    n_tokens(n_tokens() + 1)

    promises::then(
      tok,
      onFulfilled = function(tok) {
        # The first token of a reply opens a new message card; later tokens
        # are appended to that card.
        if (n_tokens() == 1) {
          insert_message(paste0("🤖: ", tok), append = FALSE)
        } else {
          insert_message(tok, append = TRUE)
        }

        if (tok != "" && n_tokens() < max_n_tokens) {
          # Growing the prompt triggers the request for the next token.
          prompt(paste0(prompt(), tok))
        } else {
          finish_reply()
        }
      },
      onRejected = function(err) {
        # Without this handler a failed generation would leave the send
        # button disabled and labelled "Waiting for model".
        reason <- if (inherits(err, "condition")) {
          conditionMessage(err)
        } else {
          paste(as.character(err), collapse = " ")
        }
        warning("Token generation failed: ", reason, call. = FALSE)
        finish_reply()
      }
    )
  })

  # We need this observer to make sure that tasks get resolved while the event
  # loop is running: it polls the background session every 5 seconds.
  observe({
    invalidateLater(5000, session)
    sess$sess$poll_process(1)
  })

  # Observer used at app startup time to allow using the 'Send' button once the
  # model has been loaded. It re-runs every second while the session is not
  # ready and the button has never been clicked.
  observe({
    ready <- sess$sess$poll_process(1) == "ready"
    send <- isolate(input$send)

    if (send == 0 && !ready) {
      invalidateLater(1000, session)
    }

    if (ready) {
      shinyjs::enable("send")
      updateActionButton(inputId = "send", label = "Send")
    } else {
      shinyjs::disable("send")
    }
  })
}
|
103 |
|
104 |
+
message_id <- 0
|
105 |
+
# Adds chat text to the element with id "messages" and scrolls it to the
# bottom.
#
# msg:    text to display.
# append: when FALSE (default) a new message card is inserted and the global
#         `message_id` counter is advanced; when TRUE `msg` is appended to the
#         text of the most recently inserted card.
insert_message <- function(msg, append = FALSE) {
  # Turns an R string into a single-quoted JavaScript string literal. The text
  # may contain quotes, backslashes or line breaks (it comes from the model),
  # which would otherwise break the generated script or inject code into it.
  js_quote <- function(x) {
    x <- gsub("\\", "\\\\", x, fixed = TRUE)
    x <- gsub("'", "\\'", x, fixed = TRUE)
    x <- gsub("\n", "\\n", x, fixed = TRUE)
    x <- gsub("\r", "\\r", x, fixed = TRUE)
    paste0("'", x, "'")
  }

  if (!append) {
    # New message: take the next id and insert a card whose paragraph carries
    # that id so later tokens can be appended to it.
    id <- message_id <<- message_id + 1
    insertUI(
      "#messages",
      "beforeEnd",
      immediate = TRUE,
      ui = card(card_body(p(id = paste0("msg-", id), msg)), style = "margin-bottom:5px;")
    )
  } else {
    id <- message_id
    shinyjs::runjs(paste0(
      "document.getElementById('msg-", id, "').textContent += ", js_quote(msg)
    ))
  }
  # scroll to bottom
  shinyjs::runjs("var elem = document.getElementById('messages'); elem.scrollTop = elem.scrollHeight;")
}
|
123 |
+
|
124 |
+
|
125 |
shinyApp(ui, server)
|
gptneox.Rproj
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Version: 1.0
|
2 |
+
|
3 |
+
RestoreWorkspace: Default
|
4 |
+
SaveWorkspace: Default
|
5 |
+
AlwaysSaveHistory: Default
|
6 |
+
|
7 |
+
EnableCodeIndexing: Yes
|
8 |
+
UseSpacesForTab: Yes
|
9 |
+
NumSpacesForTab: 2
|
10 |
+
Encoding: UTF-8
|
11 |
+
|
12 |
+
RnwWeave: knitr
|
13 |
+
LaTeX: pdfLaTeX
|
model-session.R
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
source("promise-session.R")
|
2 |
+
|
3 |
+
# A wrapper around the promise session that controls model loading and
|
4 |
+
# querying given a prompt
|
5 |
+
model_session <- R6::R6Class(
  lock_objects = FALSE,
  public = list(
    # Creates the underlying promise session and sets the sampling defaults
    # (temperature 1, top-k 50).
    initialize = function() {
      self$sess <- promise_session$new()
      self$temperature <- 1
      self$top_k <- 50
    },

    # Queues loading of the model and tokenizer for `repo` on the promise
    # session and returns whatever the session's `call()` returns; the queued
    # function itself evaluates to "done".
    load_model = function(repo) {
      loader <- function(repo) {
        library(torch)
        library(zeallot)
        library(minhub)
        # `<<-` is used so `model` and `tok` outlive this call; `generate()`
        # refers to both as free variables.
        # NOTE(review): presumably they land in the worker's global
        # environment - confirm against callr's handling of the function.
        model <<- minhub::gptneox_from_pretrained(repo)
        model$eval()
        model$to(dtype = torch_float())
        tok <<- tok::tokenizer$from_pretrained(repo)
        "done"
      }
      self$sess$call(func = loader, args = list(repo = repo))
    },

    # Queues sampling of one next token for `prompt`, using the session's
    # current `temperature` and `top_k`, and returns whatever the session's
    # `call()` returns. The queued function evaluates to the decoded token.
    generate = function(prompt) {
      sampler <- function(prompt, temperature, top_k) {
        token_ids <- torch_tensor(tok$encode(prompt)$ids)$view(c(1, -1))
        # NOTE(review): the +1L here and the -1L below look like a shift
        # between the tokenizer's ids and torch's indexing - confirm.
        with_no_grad({
          logits <- model(token_ids + 1L)
        })
        logits <- logits[, -1, ] / temperature
        # Keep only the `top_k` largest logits; everything else becomes -Inf
        # before the softmax.
        c(top_values, top_index) %<-% logits$topk(top_k)
        logits <- torch_full_like(logits, -Inf)$scatter_(-1, top_index, top_values)
        probs <- nnf_softmax(logits, dim = -1)
        sampled <- torch::torch_multinomial(probs, num_samples = 1) - 1L
        tok$decode(as.integer(sampled))
      }
      self$sess$call(
        func = sampler,
        args = list(
          prompt = prompt,
          temperature = self$temperature,
          top_k = self$top_k
        )
      )
    }
  )
)
|
promise-session.R
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Small utility class that wraps a `callr::r_session` to return promises when
|
3 |
+
# executing `sess$call()`.
|
4 |
+
# Only one promise is resolved at a time, in FIFO order.
|
5 |
+
promise_session <- R6::R6Class(
  lock_objects = FALSE,
  public = list(
    # Starts the background R session; no task is running yet.
    initialize = function() {
      self$sess <- callr::r_session$new()
      self$is_running <- FALSE
    },

    # Queues `func` (called with `args`) and returns a promise for its result.
    # The session is polled once before queueing and once more one second
    # later.
    call = function(func, args = list()) {
      self$poll_process()
      executor <- function(resolve, reject) {
        self$push_task(func, args, resolve, reject)
        later::later(self$poll_process, 1)
      }
      promises::promise(executor)
    },

    # Appends a task, together with its promise callbacks, to the end of the
    # queue and tries to start it.
    push_task = function(func, args, resolve, reject) {
      entry <- list(
        func = func,
        args = args,
        resolve = resolve,
        reject = reject
      )
      self$tasks <- c(self$tasks, list(entry))
      cat("task pushed, now we have ", length(self$tasks), " on queue\n")
      self$run_task()
      invisible(NULL)
    },

    # Starts the task at the head of the queue, unless one is already running
    # or the queue is empty.
    run_task = function() {
      if (self$is_running || length(self$tasks) == 0) {
        return(NULL)
      }

      self$is_running <- TRUE
      head_task <- self$tasks[[1]]
      self$sess$call(head_task$func, args = head_task$args)
    },

    # Reads the finished call's output, settles the promise of the task at the
    # head of the queue, removes that task and starts the next one.
    resolve_task = function() {
      reply <- self$sess$read()
      finished <- self$tasks[[1]]
      if (is.null(reply$error)) {
        finished$resolve(reply$result)
      } else {
        finished$reject(reply$error)
      }

      self$tasks <- self$tasks[-1]
      self$is_running <- FALSE

      self$run_task()
    },

    # Returns "ready" straight away when no task is running; otherwise polls
    # the background session with `timeout`, resolves the running task if the
    # poll reports "ready", and returns the poll state.
    poll_process = function(timeout = 1) {
      if (!self$is_running) {
        return("ready")
      }
      state <- self$sess$poll_process(timeout)
      if (state == "ready") {
        self$resolve_task()
      }
      state
    }
  )
)
|
61 |
+
|
62 |
+
# sess <- promise_session$new()
|
63 |
+
# f <- sess$call(function(a) {
|
64 |
+
# 10 + 1
|
65 |
+
# }, list(1))
|
66 |
+
# sess$poll_process()
|
67 |
+
|