File size: 8,944 Bytes
3c0d933
 
ca30460
b73694e
8a3fba7
ca30460
13c1f55
f3264d9
93d628c
8a3fba7
3c0d933
8a3fba7
13c1f55
8a3fba7
 
 
 
 
3c0d933
8a3fba7
 
 
 
 
 
 
640a26a
8a3fba7
 
 
 
640a26a
3c0d933
 
 
 
640a26a
 
 
 
 
 
8a3fba7
 
640a26a
c2274c7
640a26a
 
8f4e035
640a26a
8f4e035
 
c2274c7
0313f74
 
c29126b
 
640a26a
33145ed
8f4e035
 
 
 
c2274c7
8f4e035
c29126b
896d280
8f4e035
 
640a26a
 
 
c2274c7
640a26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a3fba7
640a26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0313f74
640a26a
 
 
 
8a3fba7
640a26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2274c7
 
 
 
 
 
 
 
 
 
640a26a
c2274c7
 
640a26a
 
c2274c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
640a26a
 
b1231d6
 
640a26a
 
 
 
 
b1231d6
 
640a26a
 
 
 
 
 
c2274c7
640a26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a3fba7
 
c2274c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c0d933
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
library(shiny)
library(bslib)
library(minhub)
library(magrittr)
source("model-session.R")

# Hugging Face repository of the model to serve. Defaults to a small Pythia
# checkpoint; override it with the MODEL_REPO environment variable, e.g.
# MODEL_REPO="stabilityai/stablelm-tuned-alpha-3b".
repo <- Sys.getenv("MODEL_REPO", unset = "EleutherAI/pythia-70m")

# Background model session shared by every Shiny session of this process
# (`model_session` comes from model-session.R).
sess <- model_session$new()

# Token budget for a single assistant reply (checked in `server`).
max_n_tokens <- 100

# System prompt placed in front of the conversation before the first user
# message is encoded.
system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"

# Page layout: the chat history on top, and the prompt box next to the send
# button below it.
ui <- page_fillable(
  theme = bs_theme(bootswatch = "minty"),
  # Chat history. `card()` has no `heights_equal`, `width` or `fillable`
  # arguments (the first two belong to layout_column_wrap(), the last to
  # card_body()), so passing them to card() only produced bogus HTML
  # attributes. The non-fillable intent is expressed on the card body.
  card(
    height = "90%",
    card_body(
      fillable = FALSE,
      uiOutput("messages")
    )
  ),
  # Prompt input and send button, side by side.
  layout_column_wrap(
    width = 1/2,
    textInput("prompt", label = NULL, width = "100%"),
    uiOutput("sendButton")
  )
)

server <- function(input, output, session) {
  # Token ids that end the assistant's turn.
  # NOTE(review): presumably the chat special tokens of the tokenizer plus
  # ids 0/1 -- confirm against the tokenizer in use.
  stop_token_ids <- c(50278L, 50279L, 50277L, 1L, 0L)

  # Turn a promise rejection reason into one display string. as.character()
  # on a condition yields one element per field (message *and* call), which
  # produced a two-element `content` and a garbled extra line in the chat.
  describe_error <- function(reason) {
    if (inherits(reason, "condition")) {
      conditionMessage(reason)
    } else {
      paste(as.character(reason), collapse = "\n")
    }
  }

  # Conversation state shared by the observers below.
  context <- reactiveValues(
    # "loading" until the model is ready, TRUE while a reply is being
    # generated, FALSE when idle, "error" after a failure.
    generating = "loading",
    idxs = NULL, # token ids of the whole conversation so far
    n_tokens = 0, # tokens generated so far for the current reply
    messages = list() # list of list(role = , content = ) shown in the UI
  )

  # State for loading the model in the background session. `reload` counts
  # retries; bumping it re-runs the load observer after a failure.
  loading <- reactiveValues(
    model = NULL,
    reload = NULL
  )

  # Start loading the model and wire up success/failure handling (including
  # up to 10 retries). Returns the promise chain.
  # NOTE(review): the original handler also (implicitly) returned this
  # promise; Shiny appears to hold this session's output flush until it
  # settles, which is why the startup UI is inserted with
  # insertUI(immediate = TRUE). Keep returning it unless that is re-checked.
  start_model_load <- function() {
    cat("Started loading model ....", "\n")
    pending <- sess$load_model(repo)
    sess$is_loaded <- FALSE # not yet loaded, but loading

    chain <- pending %>%
      promises::then(
        onFulfilled = function(x) {
          # NOTE(review): this blocks the whole R process for 5 seconds.
          Sys.sleep(5)
          cat("Model has been loaded!", "\n")
          context$generating <- FALSE
          sess$is_loaded <- TRUE
          TRUE
        },
        onRejected = function(x) {
          context$generating <- "error"
          msg <- list(
            role = "error",
            content = paste0("Error loading the model:\n", describe_error(x))
          )
          context$messages <- append(context$messages, list(msg))

          # Reset the shared session so the next attempt starts over.
          sess$is_loaded <- NULL # means failure!
          sess$sess <- NULL

          # `reload` starts as NULL and `NULL < 10` is logical(0), which makes
          # `if` throw; treat "never retried" as zero attempts.
          attempts <- isolate(loading$reload)
          if (is.null(attempts)) {
            attempts <- 0
          }
          if (attempts < 10) {
            Sys.sleep(5)
            loading$reload <- attempts + 1
          }
          FALSE
        }
      )
    loading$model <- chain
    chain
  }

  # Enable the 'Send' button once the model is available, starting the load
  # if no session has done so yet. This is observe() rather than
  # observeEvent(): observeEvent() runs its handler inside isolate(), where
  # invalidateLater() has no effect, so a session waiting on another
  # session's load was never re-checked.
  observe(priority = 0, {
    loading$reload # take a dependency: a bump requests a retry
    load_state <- sess$is_loaded

    if (isTRUE(load_state)) {
      # the model is already loaded; just leave the "loading" state
      context$generating <- FALSE
      return()
    }

    if (!is.null(load_state)) {
      # another session is loading the model: check again shortly
      invalidateLater(100, session)
      return()
    }

    # nobody is loading the model; start it without letting anything the
    # loader reads become a dependency of this observer
    isolate(start_model_load())
  })

  # Handle the 'Send' button: record the user's prompt, open an empty
  # assistant message and extend the token context, which starts generation.
  observeEvent(input$send, ignoreInit = TRUE, {
    prompt_text <- input$prompt
    if (is.null(prompt_text) || !nzchar(prompt_text)) {
      return() # nothing to send
    }

    # The button is only *styled* as disabled (a CSS class) while loading or
    # generating, which does not necessarily block activation (e.g. from the
    # keyboard). Only accept a prompt when idle.
    if (!isFALSE(context$generating)) {
      return()
    }

    context$generating <- TRUE
    context$n_tokens <- 0

    context$messages <- append(
      context$messages,
      list(
        list(role = "user", content = prompt_text),
        list(role = "assistant", content = "")
      )
    )

    # the first prompt of the conversation is preceded by the system prompt
    if (is.null(context$idxs)) {
      context$idxs <- sess$tok$encode(system_prompt)$ids
    }

    # wrap the prompt in the chat special tokens and append it
    prompt <- paste0("<|USER|>", prompt_text, "<|ASSISTANT|>")
    context$idxs <- c(context$idxs, sess$tok$encode(prompt)$ids)
    cat("Tokens in context: ", length(context$idxs), "\n")
  })

  # Keep the send button in sync with `context$generating`.
  observeEvent(context$generating, priority = 10, {
    state <- context$generating

    if (identical(state, "loading")) {
      # At startup the button and an info message are inserted directly
      # (immediate = TRUE); rendered outputs would only appear after all
      # observers ran, including the one that loads the model.
      insertUI(
        "#sendButton",
        ui = ui_send(list(
          class = "btn-secondary disabled",
          label = "Loading model ...",
          icon = icon("spinner", class = "fa-spin")
        )),
        immediate = TRUE
      )
      insertUI(
        "#messages",
        ui = ui_messages(list(list(
          role = "info",
          content = "Model is loading. It might take some time."
        ))),
        immediate = TRUE
      )
      return()
    }

    btn <- if (identical(state, "error")) {
      list(class = "btn-secondary disabled", label = "Generating error ...")
    } else if (isTRUE(state)) {
      list(
        class = "btn-secondary disabled",
        label = "Generating response ...",
        icon = icon("spinner", class = "fa-spin")
      )
    } else {
      list(class = "btn-primary", label = "Send")
    }

    output$sendButton <- renderUI(ui_send(btn))
  })

  # Re-render the chat history whenever the message list changes.
  observeEvent(context$messages, priority = 10, {
    output$messages <- renderUI({
      ui_messages(context$messages)
    })
  })

  # Generate the next token whenever the token context grows during a reply,
  # append it to the last (assistant) message and feed it back into the
  # context, which re-triggers this observer.
  observeEvent(context$idxs, {
    # Only generate while a reply is in progress. This lets the final token of
    # a reply that hit the length limit be recorded in `idxs` without
    # starting another generation step.
    if (!isTRUE(context$generating)) {
      return()
    }

    sess$generate(context$idxs) %>%
      promises::then(
        onFulfilled = function(id) {
          if (id %in% stop_token_ids) {
            context$generating <- FALSE
            context$n_tokens <- 0
            return() # special tokens that stop generation.
          }

          # append the decoded token to the last message
          messages <- context$messages
          last <- length(messages)
          messages[[last]]$content <- paste0(
            messages[[last]]$content,
            sess$tok$decode(id)
          )
          context$messages <- messages

          # `>=` caps the reply at exactly `max_n_tokens` tokens. On hitting
          # the cap, leave generation mode *before* recording the token so
          # the guard above ends the loop.
          context$n_tokens <- context$n_tokens + 1
          if (context$n_tokens >= max_n_tokens) {
            context$generating <- FALSE
            context$n_tokens <- 0
          }

          context$idxs <- c(context$idxs, id)
        },
        onRejected = function(x) {
          # report the failure in the chat and leave generation mode
          msg <- list(
            role = "error",
            content = paste0("Error generating token. ", describe_error(x))
          )
          context$messages <- append(context$messages, list(msg))
          context$generating <- "error"
        }
      )

    # Return NULL rather than the promise chain, as the original did, so the
    # observer's value is not the pending promise.
    NULL
  })
}

ui_messages <- function(messages) {
  # Speaker prefix for each chat role.
  role_emoji <- c(user = "🤗", assistant = "🤖", info = "📣", error = "😭")

  # One small card per message: "<emoji>: <content>".
  render_one <- function(m) {
    label <- paste0(role_emoji[m$role], ": ", m$content)
    card(style="margin-bottom:5px;", card_body(p(label)))
  }

  # Stack all message cards in a single non-fillable card body.
  cards <- lapply(messages, render_one)
  do.call(card_body, c(cards, list(gap = 5, fillable = FALSE)))
}

ui_send <- function(btn) {
  # Full-width 'send' action button built from a spec list with `label`,
  # `class` and an optional `icon` (NULL means no icon).
  actionButton(
    inputId = "send",
    label = btn$label,
    icon = btn$icon,
    width = "100%",
    class = btn$class
  )
}

# Build the Shiny app object from the UI definition and the server function.
shinyApp(ui = ui, server = server)