Spaces:
Runtime error
Runtime error
File size: 8,943 Bytes
3c0d933 ca30460 b73694e 8a3fba7 ca30460 13c1f55 173d645 93d628c 8a3fba7 3c0d933 8a3fba7 13c1f55 8a3fba7 3c0d933 8a3fba7 640a26a 8a3fba7 640a26a 3c0d933 640a26a 8a3fba7 640a26a c2274c7 640a26a 8f4e035 640a26a 8f4e035 c2274c7 0313f74 c29126b 640a26a 33145ed 8f4e035 c2274c7 8f4e035 c29126b 896d280 8f4e035 640a26a c2274c7 640a26a 8a3fba7 640a26a 0313f74 640a26a 8a3fba7 640a26a c2274c7 640a26a c2274c7 640a26a c2274c7 640a26a b1231d6 640a26a b1231d6 640a26a c2274c7 640a26a 8a3fba7 c2274c7 3c0d933 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 |
library(shiny)
library(bslib)
library(minhub)
library(magrittr)
source("model-session.R")
# Hugging Face model repository to load. Override it with the MODEL_REPO
# environment variable (e.g. MODEL_REPO=EleutherAI/pythia-70m for a small
# model). NOTE(review): the prompt format and stop-token ids used in server()
# look StableLM-specific -- confirm before switching to another model family.
repo <- Sys.getenv(
  "MODEL_REPO",
  unset = "stabilityai/stablelm-tuned-alpha-3b"
)
# A single model session, created at top level and therefore shared by every
# connected Shiny session.
sess <- model_session$new()
# Maximum number of tokens generated for a single assistant reply.
max_n_tokens <- 100
# System prompt, tokenized once and placed at the start of the token context
# on the first 'Send'. User turns are appended afterwards wrapped in
# <|USER|> ... <|ASSISTANT|> markers. The string intentionally ends with a
# newline.
system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"
# Page layout: a tall card holding the chat transcript and, below it, the
# prompt text box next to the send button. Both the transcript and the button
# are rendered from server().
ui <- page_fillable(
  theme = bs_theme(bootswatch = "minty"),
  card(
    height = "90%",
    # `heights_equal`/`width` are layout_column_wrap() arguments and
    # `fillable` is a card_body() argument; passed to card() they only ended
    # up as stray HTML attributes. An explicit non-fillable body lets the
    # transcript flow in normal document order.
    card_body(
      uiOutput("messages"),
      fillable = FALSE
    )
  ),
  layout_column_wrap(
    width = 1/2,
    textInput("prompt", label = NULL, width = "100%"),
    uiOutput("sendButton")
  )
)
server <- function(input, output, session) {
  # Token ids that end an assistant reply. NOTE(review): presumably the
  # StableLM (GPT-NeoX) special tokens <|SYSTEM|>/<|USER|>/<|ASSISTANT|>,
  # padding and end-of-text -- confirm against the tokenizer.
  stop_token_ids <- c(50278L, 50279L, 50277L, 1L, 0L)
  # Number of times a failed model load is retried.
  max_load_retries <- 10

  # Turns a promise rejection reason into readable text. as.character() on a
  # condition returns one string per field (message and call), so the
  # condition message is used instead.
  condition_text <- function(x) {
    if (inherits(x, "condition")) {
      conditionMessage(x)
    } else {
      paste(as.character(x), collapse = " ")
    }
  }

  # Chat state for this browser session.
  context <- reactiveValues(
    # "loading" until the model is ready; afterwards TRUE while generating,
    # FALSE when idle and "error" after a failure
    generating = "loading",
    idxs = NULL,      # token ids of the whole conversation so far
    n_tokens = 0,     # tokens generated for the current reply
    messages = list() # list(role, content) entries shown in the UI
  )

  # State of the background model load. Changing `reload` or `poll` re-runs
  # the loader observer below.
  loading <- reactiveValues(
    model = NULL, # promise chain started by sess$load_model()
    reload = 0L,  # retries performed after failed loads
    poll = 0L     # bumped once another session's load has settled
  )

  # Polls until a load started by another session settles (sess$is_loaded
  # stops being FALSE), then re-runs the loader observer. invalidateLater()
  # cannot be used inside the loader itself: observeEvent() runs its handler
  # inside isolate(), where invalidateLater() has no observer to invalidate.
  wait_for_other_load <- function() {
    waiter <- observe({
      if (identical(sess$is_loaded, FALSE)) {
        invalidateLater(100, session)
        return()
      }
      waiter$destroy()
      loading$poll <- isolate(loading$poll) + 1L
    })
    invisible(waiter)
  }

  # Loader: starts loading the model in the background (once, shared through
  # `sess`) and enables the 'Send' button when it is ready. Priority 0 lets
  # the button observer (priority 10) insert the loading UI first.
  observeEvent(
    list(loading$reload, loading$poll),
    ignoreInit = FALSE,
    ignoreNULL = FALSE,
    priority = 0,
    {
      # already loaded (possibly by another session): just enable sending
      if (isTRUE(sess$is_loaded)) {
        context$generating <- FALSE
        return()
      }
      # FALSE means another session is currently loading the model
      if (!is.null(sess$is_loaded)) {
        wait_for_other_load()
        return()
      }
      cat("Started loading model ....", "\n")
      loading$model <- sess$load_model(repo)
      sess$is_loaded <- FALSE # not yet loaded, but loading
      loading$model <- loading$model %>%
        promises::then(
          onFulfilled = function(x) {
            # NOTE(review): Sys.sleep() blocks the whole R process, i.e. all
            # sessions; the reason for this pause is not visible here.
            Sys.sleep(5)
            cat("Model has been loaded!", "\n")
            context$generating <- FALSE
            sess$is_loaded <- TRUE
            TRUE
          },
          onRejected = function(err) {
            context$generating <- "error"
            msg <- list(
              role = "error",
              content = paste0(
                "Error loading the model:\n",
                condition_text(err)
              )
            )
            context$messages <- append(isolate(context$messages), list(msg))
            # reset the shared session so the next attempt starts fresh
            sess$is_loaded <- NULL
            sess$sess <- NULL
            retries <- isolate(loading$reload)
            if (retries < max_load_retries) {
              Sys.sleep(5)
              loading$reload <- retries + 1L
            }
            FALSE
          }
        )
    }
  )

  # 'Send' handler: records the user's prompt and an empty assistant reply,
  # then extends the token context, which triggers generation.
  observeEvent(input$send, ignoreInit = TRUE, {
    # the button is only styled as disabled (a CSS class), so it may still be
    # activated, e.g. from the keyboard; only accept sends while idle
    if (!identical(context$generating, FALSE)) {
      return()
    }
    if (is.null(input$prompt) || input$prompt == "") {
      return()
    }
    context$generating <- TRUE
    context$messages <- append(
      context$messages,
      list(
        list(role = "user", content = input$prompt),
        list(role = "assistant", content = "")
      )
    )
    # the first send starts the context with the system prompt
    if (is.null(context$idxs)) {
      context$idxs <- sess$tok$encode(system_prompt)$ids
    }
    prompt <- paste0("<|USER|>", input$prompt, "<|ASSISTANT|>")
    context$idxs <- c(context$idxs, sess$tok$encode(prompt)$ids)
    cat("Tokens in context: ", length(context$idxs), "\n")
  })

  # Send-button state. While the model loads ("loading") the button and an
  # info message are inserted directly, because outputs would only render
  # after all observers ran, including the one that loads the model.
  observeEvent(context$generating, priority = 10, {
    state <- context$generating
    if (identical(state, "loading")) {
      btn <- list(
        class = "btn-secondary disabled",
        label = "Loading model ...",
        icon = icon("spinner", class = "fa-spin")
      )
      insertUI(
        "#sendButton",
        ui = ui_send(btn),
        immediate = TRUE
      )
      msgs <- list(list(
        role = "info",
        content = "Model is loading. It might take some time."
      ))
      insertUI(
        "#messages",
        ui = ui_messages(msgs),
        immediate = TRUE
      )
      return()
    }
    btn <- if (identical(state, "error")) {
      list(class = "btn-secondary disabled", label = "Generating error ...")
    } else if (isTRUE(state)) {
      list(
        class = "btn-secondary disabled",
        label = "Generating response ...",
        icon = icon("spinner", class = "fa-spin")
      )
    } else {
      list(class = "btn-primary", label = "Send")
    }
    output$sendButton <- renderUI(ui_send(btn))
  })

  # Re-renders the transcript whenever the message list changes.
  observeEvent(context$messages, priority = 10, {
    output$messages <- renderUI({
      ui_messages(context$messages)
    })
  })

  # Generates one token per run: feeds context$idxs to the model and, when
  # the token arrives, appends it to the last (assistant) message and to
  # context$idxs, which re-triggers this observer for the next token.
  observeEvent(context$idxs, {
    # only generate while a reply is in progress; this also allows storing
    # the final token in context$idxs without starting another generation
    if (!isTRUE(context$generating)) {
      return()
    }
    context$idxs %>%
      sess$generate() %>%
      promises::then(
        onFulfilled = function(id) {
          if (id %in% stop_token_ids) {
            # special tokens end the reply and are not added to the context
            context$generating <- FALSE
            context$n_tokens <- 0
            return()
          }
          messages <- context$messages
          last <- length(messages)
          messages[[last]]$content <- paste0(
            messages[[last]]$content,
            sess$tok$decode(id)
          )
          context$messages <- messages
          context$n_tokens <- context$n_tokens + 1
          if (context$n_tokens >= max_n_tokens) {
            # length limit reached: stop, but still record the token below
            # so the context matches the text shown to the user
            context$generating <- FALSE
            context$n_tokens <- 0
          }
          context$idxs <- c(context$idxs, id)
        },
        onRejected = function(x) {
          msg <- list(
            role = "error",
            content = paste0("Error generating token: ", condition_text(x))
          )
          context$messages <- append(context$messages, list(msg))
          context$n_tokens <- 0
          # a distinct flag value keeps the button disabled with an error label
          context$generating <- "error"
        }
      )
    NULL
  })
}
# Render the chat transcript: one small card per message, each prefixed with
# an emoji for its role, stacked inside a single non-fillable card body.
# `messages` is a list of list(role = , content = ) entries.
ui_messages <- function(messages) {
  emojis <- c(user = "🤗", assistant = "🤖", info = "📣", error = "😭")
  msg_cards <- lapply(messages, function(msg) {
    card(
      style = "margin-bottom:5px;",
      card_body(
        # pre-wrap keeps line breaks in model output and error texts, which
        # a plain <p> would collapse into spaces
        p(
          paste0(emojis[msg$role], ": ", msg$content),
          style = "white-space: pre-wrap;"
        )
      )
    )
  })
  do.call(card_body, c(msg_cards, list(gap = 5, fillable = FALSE)))
}
ui_send <- function(btn) {
  # Full-width "send" action button. `btn` is a list supplying the button's
  # `label`, its CSS `class` and, optionally, an `icon`.
  actionButton(
    inputId = "send",
    label = btn$label,
    icon = btn$icon,
    width = "100%",
    class = btn$class
  )
}
shinyApp(ui, server)
|