dfalbel committed on
Commit 8a3fba7
1 Parent(s): dc41f47

Make the app nicer!

Files changed (5)
  1. Dockerfile +2 -1
  2. app.R +109 -37
  3. gptneox.Rproj +13 -0
  4. model-session.R +45 -0
  5. promise-session.R +67 -0
Dockerfile CHANGED
@@ -5,7 +5,8 @@ WORKDIR /code
 # Install stable packages from CRAN
 RUN install2.r --error \
   ggExtra \
-  shiny
+  shiny \
+  callr
 
 # Install Rust for tok
 
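For reference, `install2.r` is the littler helper script typically shipped with the rocker images; the amended RUN line is roughly equivalent to the following R call (a sketch, assuming the CRAN repo configured in the base image):

# Rough R equivalent of the RUN line above: ggExtra and shiny from CRAN,
# plus callr, which promise-session.R uses to run the model in a
# background R session.
install.packages(c("ggExtra", "shiny", "callr"))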
app.R CHANGED
@@ -1,53 +1,125 @@
 library(shiny)
 library(bslib)
-library(dplyr)
-library(ggplot2)
 library(minhub)
+source("model-session.R")
 
-model <- gptneox()
+repo <- "stabilityai/stablelm-tuned-alpha-3b"
+sess <- model_session$new()
+model_loaded <- sess$load_model(repo)
 
-# Find subset of columns that are suitable for scatter plot
-df_num <- df |> select(where(is.numeric), -Year)
+max_n_tokens <- 100
+system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
+- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
+- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
+- StableLM will refuse to participate in anything that could harm a human.
+"
 
-ui <- page_fillable(theme = bs_theme(bootswatch = "minty"),
-  layout_sidebar(fillable = TRUE,
-    sidebar(
-      varSelectInput("xvar", "X variable", df_num, selected = "Bill Length (mm)"),
-      varSelectInput("yvar", "Y variable", df_num, selected = "Bill Depth (mm)"),
-      checkboxGroupInput("species", "Filter by species",
-        choices = unique(df$Species), selected = unique(df$Species)
-      ),
-      hr(), # Add a horizontal rule
-      checkboxInput("by_species", "Show species", TRUE),
-      checkboxInput("show_margins", "Show marginal plots", TRUE),
-      checkboxInput("smooth", "Add smoother"),
-    ),
-    plotOutput("scatter")
+ui <- page_fillable(
+  theme = bs_theme(bootswatch = "minty"),
+  shinyjs::useShinyjs(),
+  card(
+    height = "90%",
+    heights_equal = "row",
+    width = 1,
+    fillable = FALSE,
+    card_body(id = "messages", gap = 5, fillable = FALSE)
+  ),
+  layout_column_wrap(
+    width = 1/2,
+    textInput("prompt", label = NULL, width = "100%"),
+    actionButton("send", "Loading model...", width = "100%")
   )
 )
 
 server <- function(input, output, session) {
-  subsetted <- reactive({
-    req(input$species)
-    df |> filter(Species %in% input$species)
-  })
 
-  output$scatter <- renderPlot({
-    p <- ggplot(subsetted(), aes(!!input$xvar, !!input$yvar)) + list(
-      theme(legend.position = "bottom"),
-      if (input$by_species) aes(color=Species),
-      geom_point(),
-      if (input$smooth) geom_smooth()
-    )
-
-    if (input$show_margins) {
-      margin_type <- if (input$by_species) "density" else "histogram"
-      p <- p |> ggExtra::ggMarginal(type = margin_type, margins = "both",
-        size = 8, groupColour = input$by_species, groupFill = input$by_species)
-    }
-
-    p
-  }, res = 100)
+  prompt <- reactiveVal(value = system_prompt)
+  n_tokens <- reactiveVal(value = 0)
+
+  observeEvent(input$send, {
+    if (is.null(input$prompt) || input$prompt == "") {
+      return()
+    }
+    shinyjs::disable("send")
+    updateActionButton(inputId = "send", label = "Waiting for model")
+    insert_message(as.character(glue::glue("🤗: {input$prompt}")))
+
+    # Modifying the prompt triggers the 'next_token' reactive.
+    prompt(paste0(prompt(), "<|USER|>", input$prompt, "<|ASSISTANT|>"))
+  })
+
+  next_token <- eventReactive(prompt(), ignoreInit = TRUE, {
+    prompt() %>%
+      sess$generate()
+  })
+
+  observeEvent(next_token(), {
+    tok <- next_token()
+    n_tokens(n_tokens() + 1)
+
+    tok %>% promises::then(function(tok) {
+      if (n_tokens() == 1) {
+        insert_message(paste0("🤖: ", tok), append = FALSE)
+      } else {
+        insert_message(tok, append = TRUE)
+      }
+
+      if (tok != "" && n_tokens() < max_n_tokens) {
+        prompt(paste0(prompt(), tok))
+      } else {
+        shinyjs::enable("send")
+        updateActionButton(inputId = "send", label = "Send")
+        n_tokens(0)
+      }
+    })
+  })
+
+  # This observer makes sure that pending promises get resolved during
+  # the Shiny event loop.
+  observe({
+    invalidateLater(5000, session)
+    sess$sess$poll_process(1)
+  })
+
+  # Observer used at app startup to enable the 'Send' button once the
+  # model has been loaded.
+  observe({
+    ready <- sess$sess$poll_process(1) == "ready"
+    send <- isolate(input$send)
+
+    if (send == 0 && !ready) {
+      invalidateLater(1000, session)
+    }
+
+    if (ready) {
+      shinyjs::enable("send")
+      updateActionButton(inputId = "send", label = "Send")
+    } else {
+      shinyjs::disable("send")
+    }
+  })
 }
 
+message_id <- 0
+insert_message <- function(msg, append = FALSE) {
+  if (!append) {
+    id <- message_id <<- message_id + 1
+    insertUI(
+      "#messages",
+      "beforeEnd",
+      immediate = TRUE,
+      ui = card(card_body(p(id = paste0("msg-", id), msg)), style = "margin-bottom:5px;")
+    )
+  } else {
+    id <- message_id
+    shinyjs::runjs(glue::glue(
+      "document.getElementById('msg-{id}').textContent += '{msg}'"
+    ))
+  }
+  # Scroll the message list to the bottom.
+  shinyjs::runjs("var elem = document.getElementById('messages'); elem.scrollTop = elem.scrollHeight;")
+}
+
+
 shinyApp(ui, server)
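The chat logic above keeps one growing prompt string in the StableLM-Tuned chat format: the `<|SYSTEM|>` block first, each user message wrapped in `<|USER|>`/`<|ASSISTANT|>` markers, and every generated token appended back until the model emits an empty token or `max_n_tokens` is reached. A small sketch of just that string handling, separate from the reactive machinery (the user text and tokens below are invented):

# Start from the system prompt defined in app.R.
prompt <- system_prompt

# Sending a message wraps it in the chat markers before generation starts.
prompt <- paste0(prompt, "<|USER|>", "Tell me a joke.", "<|ASSISTANT|>")

# Each token resolved from sess$generate() is appended to the prompt so the
# next call continues the assistant's reply; an empty token ends the reply.
for (tok in c("Why", " did", " the", " chicken", "")) {
  if (tok == "") break
  prompt <- paste0(prompt, tok)
}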
gptneox.Rproj ADDED
@@ -0,0 +1,13 @@
+Version: 1.0
+
+RestoreWorkspace: Default
+SaveWorkspace: Default
+AlwaysSaveHistory: Default
+
+EnableCodeIndexing: Yes
+UseSpacesForTab: Yes
+NumSpacesForTab: 2
+Encoding: UTF-8
+
+RnwWeave: knitr
+LaTeX: pdfLaTeX
model-session.R ADDED
@@ -0,0 +1,45 @@
+source("promise-session.R")
+
+# A wrapper around the promise session that controls model loading and
+# querying given a prompt.
+model_session <- R6::R6Class(
+  lock_objects = FALSE,
+  public = list(
+    initialize = function() {
+      self$sess <- promise_session$new()
+      self$temperature <- 1
+      self$top_k <- 50
+    },
+    load_model = function(repo) {
+      self$sess$call(args = list(repo = repo), function(repo) {
+        library(torch)
+        library(zeallot)
+        library(minhub)
+        model <<- minhub::gptneox_from_pretrained(repo)
+        model$eval()
+        model$to(dtype = torch_float())
+        tok <<- tok::tokenizer$from_pretrained(repo)
+        "done"
+      })
+    },
+    generate = function(prompt) {
+      args <- list(
+        prompt = prompt,
+        temperature = self$temperature,
+        top_k = self$top_k
+      )
+      self$sess$call(args = args, function(prompt, temperature, top_k) {
+        idx <- torch_tensor(tok$encode(prompt)$ids)$view(c(1, -1))
+        with_no_grad({
+          logits <- model(idx + 1L)
+        })
+        logits <- logits[, -1, ] / temperature
+        c(prob, ind) %<-% logits$topk(top_k)
+        logits <- torch_full_like(logits, -Inf)$scatter_(-1, ind, prob)
+        logits <- nnf_softmax(logits, dim = -1)
+        id_next <- torch::torch_multinomial(logits, num_samples = 1) - 1L
+        tok$decode(as.integer(id_next))
+      })
+    }
+  )
+)
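A minimal usage sketch of `model_session` from a plain R session (not part of the commit itself): `load_model()` and `generate()` both return promises backed by the background `callr` session, and the polling loop below peeks at the promise session's internal task queue to drive resolution. The short prompt text is illustrative, and downloading the 3B checkpoint is assumed to be feasible.

source("model-session.R")

sess <- model_session$new()

# Both methods return promises; generate() is queued until load_model() finishes.
loaded <- sess$load_model("stabilityai/stablelm-tuned-alpha-3b")
promises::then(loaded, function(status) cat("load_model resolved with:", status, "\n"))

nxt <- sess$generate("<|SYSTEM|>You are helpful.<|USER|>Hi!<|ASSISTANT|>")
promises::then(nxt, function(tok) cat("next token:", tok, "\n"))

# Outside Shiny nothing polls the session for us, so drain the task queue and
# let 'later' run the promise callbacks as they become ready.
while (length(sess$sess$tasks) > 0) {
  sess$sess$poll_process(100)
  later::run_now(0.1)
}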
promise-session.R ADDED
@@ -0,0 +1,67 @@
+
+# Small utility class that wraps a `callr::r_session` to return promises when
+# executing `sess$call()`.
+# Only one task runs at a time; results are resolved in FIFO order.
+promise_session <- R6::R6Class(
+  lock_objects = FALSE,
+  public = list(
+    initialize = function() {
+      self$sess <- callr::r_session$new()
+      self$is_running <- FALSE
+    },
+    call = function(func, args = list()) {
+      self$poll_process()
+      promises::promise(function(resolve, reject) {
+        self$push_task(func, args, resolve, reject)
+        later::later(self$poll_process, 1)
+      })
+    },
+    push_task = function(func, args, resolve, reject) {
+      self$tasks[[length(self$tasks) + 1]] <- list(
+        func = func,
+        args = args,
+        resolve = resolve,
+        reject = reject
+      )
+      cat("task pushed, now we have ", length(self$tasks), " in the queue\n")
+      self$run_task()
+      invisible(NULL)
+    },
+    run_task = function() {
+      if (self$is_running) return(NULL)
+      if (length(self$tasks) == 0) return(NULL)
+
+      self$is_running <- TRUE
+      task <- self$tasks[[1]]
+      self$sess$call(task$func, args = task$args)
+    },
+    resolve_task = function() {
+      out <- self$sess$read()
+      if (!is.null(out$error)) {
+        self$tasks[[1]]$reject(out$error)
+      } else {
+        self$tasks[[1]]$resolve(out$result)
+      }
+
+      self$tasks <- self$tasks[-1]
+      self$is_running <- FALSE
+
+      self$run_task()
+    },
+    poll_process = function(timeout = 1) {
+      if (!self$is_running) return("ready")
+      poll_state <- self$sess$poll_process(timeout)
+      if (poll_state == "ready") {
+        self$resolve_task()
+      }
+      poll_state
+    }
+  )
+)
+
+# sess <- promise_session$new()
+# f <- sess$call(function(a) {
+#   10 + 1
+# }, list(1))
+# sess$poll_process()
+