File size: 5,533 Bytes
3c0d933
 
ca30460
b73694e
8a3fba7
ca30460
93d628c
 
8a3fba7
3c0d933
8a3fba7
 
 
 
 
 
 
3c0d933
8a3fba7
 
 
 
 
 
 
 
 
 
 
 
 
33145ed
3c0d933
 
 
 
8a3fba7
 
0313f74
8a3fba7
 
 
 
3c0d933
8a3fba7
896d280
0313f74
8a3fba7
 
 
 
 
 
 
896d280
 
 
 
0313f74
896d280
 
 
 
8a3fba7
 
 
 
 
896d280
8a3fba7
 
0313f74
8a3fba7
0313f74
8a3fba7
 
 
 
 
 
 
 
 
 
 
 
33145ed
 
 
 
 
 
 
8a3fba7
 
0313f74
 
 
 
8f4e035
 
 
 
 
 
0313f74
 
 
 
 
 
 
 
c29126b
 
0313f74
33145ed
8f4e035
 
 
 
 
 
c29126b
896d280
8f4e035
 
c29126b
0313f74
55d99af
0313f74
55d99af
 
896d280
0313f74
55d99af
 
9ae3123
3f76e67
c29126b
0313f74
 
 
 
 
55d99af
0313f74
8a3fba7
3c0d933
 
0313f74
8a3fba7
0313f74
 
 
8a3fba7
 
 
 
0313f74
 
 
8a3fba7
 
0313f74
8a3fba7
 
 
 
 
 
0313f74
8a3fba7
 
 
3c0d933
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
library(shiny)
library(bslib)
library(minhub)
library(magrittr)
source("model-session.R")

# Model repository to load. Defaults to a small Pythia model; set the
# MODEL_REPO environment variable to serve a different one.
repo <- Sys.getenv("MODEL_REPO", unset = "EleutherAI/pythia-70m")

# A single model session created at startup, outside `server()`, so every
# Shiny session of this process uses the same object.
sess <- model_session$new()

# Maximum number of tokens streamed for a single assistant reply.
max_n_tokens <- 100

# Conversation preamble; user turns and assistant replies are appended to it
# by the server.
system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"

# User interface: a card holding the chat log (the `#messages` container that
# `insert_message()` writes into) and, below it, the prompt box and the
# "Send" button side by side.
#
# Only real `card()` arguments are passed: bslib turns any other named
# argument of `card()` into an HTML attribute on the card element.
ui <- page_fillable(
  theme = bs_theme(bootswatch = "minty"),
  shinyjs::useShinyjs(),
  card(
    height = "90%",
    # Chat log; not fillable, so appended messages are laid out as regular
    # (non-fill) children.
    card_body(id = "messages", gap = 5, fillable = FALSE)
  ),
  layout_column_wrap(
    width = 1/2,
    textInput("prompt", label = NULL, width = "100%"),
    actionButton("send", "Send", width = "100%")
  )
)

# Shiny server: streams the model reply token by token.
#
# Clicking "Send" appends the user's text to `prompt()`. Every change of
# `prompt()` triggers `next_token()`, which asks the shared model session for
# the next token. Each received token is shown in the chat log and appended to
# `prompt()`, which requests the following token. Streaming stops on an empty
# token or after `max_n_tokens` tokens.
server <- function(input, output, session) {
  prompt <- reactiveVal(value = system_prompt)
  # number of tokens received so far for the reply being streamed
  n_tokens <- reactiveVal(value = 0)
  # id of the last message inserted in the chat log (see insert_message())
  msg_id <- reactiveVal(value = 0)

  observeEvent(input$send, {
    if (is.null(input$prompt) || input$prompt == "") {
      return()
    }
    shinyjs::disable("send")
    updateActionButton(inputId = "send", label = "Waiting for model...")
    insert_message(msg_id, as.character(glue::glue("🤗: {input$prompt}")))

    # we modify the prompt to trigger the 'next_token' reactive
    prompt(paste0(prompt(), "<|USER|>", input$prompt, "<|ASSISTANT|>"))
  })

  # Promise of the next token for the current prompt. On failure the error is
  # reported in the chat log and the promise resolves to NULL.
  next_token <- eventReactive(prompt(), ignoreInit = TRUE, {
    prompt() %>%
      sess$generate() %>%
      promises::then(
        onFulfilled = function(x) {x},
        onRejected = function(x) {
          insert_message(
            msg_id,
            paste0("😭 Error generating token.", as.character(x))
          )
          updateActionButton(
            inputId = "send",
            label = "Failing generation. Contact admin."
          )
          NULL
        }
      )
  })

  observeEvent(next_token(), {
    tok <- next_token()

    n_tokens(n_tokens() + 1)
    tok %>% promises::then(function(tok) {
      # Generation failed and was already reported by `next_token()`.
      # `NULL != ""` is zero-length and would make the `if` below error, so
      # stop streaming here and reset the counter.
      if (is.null(tok)) {
        n_tokens(0)
        return(NULL)
      }

      if (n_tokens() == 1) {
        insert_message(msg_id, paste0("🤖: ", tok), append = FALSE)
      } else {
        insert_message(msg_id, tok, append = TRUE)
      }

      if (tok != "" && n_tokens() < max_n_tokens) {
        prompt(paste0(prompt(), tok))
      } else {
        shinyjs::enable("send")
        updateActionButton(inputId = "send", label = "Send")
        n_tokens(0)
      }
    })
  })

  observe({
    # an observer that makes sure tasks are resolved during the shiny loop
    invalidateLater(5000, session)
    if (!is.null(sess$sess)) {
      sess$sess$poll_process(1)
    }
  })

  # Observer used at app startup time to allow using the 'Send' button once the
  # model has been loaded.
  #
  # `sess$is_loaded` is NULL when no load is in progress (initially or after a
  # failure), FALSE while a load is running and TRUE once it succeeded.
  model_loaded <- reactiveVal()
  # bumped to retry loading after a failure (at most 10 retries, see below)
  event_reload <- reactiveVal(value = 0)
  observeEvent(event_reload(), ignoreNULL = FALSE, {

    # the model is already loaded, we just make sure the send button is enabled
    if (!is.null(sess$is_loaded) && sess$is_loaded) {
      shinyjs::enable("send")
      updateActionButton(inputId = "send", label = "Send")
      return()
    }

    # the model isn't loaded, so we disable the send button and
    # show that we are loading the model
    shinyjs::disable("send")
    updateActionButton(inputId = "send", label = "Loading the model...")

    # the model isn't loaded and no task is trying to load it, so we start a new
    # task to load it
    if (is.null(sess$is_loaded)) {
      cat("Started loading model ....", "\n")
      model_loaded(sess$load_model(repo))
      sess$is_loaded <- FALSE # not yet loaded, but loading
    } else {
      # the model is loading, but this is handled by another session. We should
      # come back to this observer later to enable the send button once model
      # is loaded.
      invalidateLater(5000, session)
      return()
    }

    # this runs for the cases where sess$is_loaded was NULL
    # ie there was no model currently loading.
    # NOTE(review): unlike the polling observer above (which passes 1), this
    # calls poll_process() without a timeout -- confirm that is supported.
    cat("Loading model:", sess$sess$poll_process(), "\n")
    m <- model_loaded() %>%
      promises::then(onFulfilled = function(x) {
        cat("Model has been loaded!", "\n")
        shinyjs::enable("send")
        updateActionButton(inputId = "send", label = "Send")
        sess$is_loaded <- TRUE
        TRUE
      }, onRejected = function(x) {
        shinyjs::disable("send")
        insert_message(
          msg_id,
          paste0("😭 Error loading the model:\n", as.character(x))
        )
        sess$is_loaded <- NULL # means failure!
        sess$sess <- NULL
        if (event_reload() < 10) {
          # NOTE(review): Sys.sleep() blocks the whole R process (every
          # connected session) for 5 seconds; a non-blocking delay such as
          # later::later() may be preferable.
          Sys.sleep(5)
          event_reload(event_reload() + 1)
        }
        FALSE
      })
    model_loaded(m)
  })
}

# Escape `x` for use inside a single-quoted JavaScript string literal.
# Backslashes are escaped first so the escapes added afterwards are not
# doubled. Quotes and the line terminators that are illegal in a JS string
# literal (LF, CR) are escaped too.
escape_js_string <- function(x) {
  x <- gsub("\\", "\\\\", x, fixed = TRUE)
  x <- gsub("'", "\\'", x, fixed = TRUE)
  x <- gsub("\n", "\\n", x, fixed = TRUE)
  x <- gsub("\r", "\\r", x, fixed = TRUE)
  x
}

# Add a message to the chat log, or extend the last one.
#
# `message_id` is the reactiveVal holding the id of the last inserted message.
# When `append` is FALSE, a new card whose paragraph has id "msg-<id>" is
# inserted at the end of `#messages`, and `message_id` is incremented. When
# `append` is TRUE, `msg` is appended to the text of the last message. Either
# way the chat log is then scrolled to the bottom. Returns the message id.
insert_message <- function(message_id, msg, append = FALSE) {
  if (!append) {
    id <- message_id() + 1
    message_id(id)

    insertUI(
      "#messages",
      "beforeEnd",
      immediate = TRUE,
      ui = card(style = "margin-bottom:5px;", card_body(
        p(id = paste0("msg-", id), msg)
      ))
    )
  } else {
    id <- message_id()
    # `msg` is model output and may contain quotes, backslashes or newlines;
    # escape it so it cannot break out of (or inject into) the JS literal.
    js_msg <- escape_js_string(msg)
    shinyjs::runjs(glue::glue(
      "document.getElementById('msg-{id}').textContent += '{js_msg}'"
    ))
  }
  # scroll to bottom
  shinyjs::runjs("var elem = document.getElementById('messages'); elem.scrollTop = elem.scrollHeight;")
  id
}


# App object combining the UI and server defined above.
shinyApp(ui = ui, server = server)