# Spaces: Runtime error
library(shiny)
library(bslib)
library(minhub)
library(magrittr)

# Defines `model_session` (model-session.R is not shown here).
source("model-session.R")

# Hugging Face repository of the model to load. Override it with the
# MODEL_REPO environment variable, e.g. MODEL_REPO="EleutherAI/pythia-70m".
repo <- Sys.getenv("MODEL_REPO", unset = "stabilityai/stablelm-tuned-alpha-3b")

# Created once at top level, so this model session is shared by every user
# session served by this R process.
sess <- model_session$new()

# Token budget for a single assistant reply; the server compares the number
# of tokens generated for the current reply against it.
max_n_tokens <- 100

# System prompt prepended (once) to the token context of a conversation.
system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"
# Page layout: a tall card holding the conversation, and below it a row with
# the prompt box and the send button (the button itself is rendered by the
# server into the "sendButton" output).
ui <- page_fillable(
  theme = bs_theme(bootswatch = "minty"),
  shinyjs::useShinyjs(),
  # Conversation history, filled by output$messages.
  # `heights_equal`/`width` were dropped here: they are layout_column_wrap()
  # arguments and, passed to card(), only became stray HTML attributes.
  card(
    height = "90%",
    fillable = FALSE,
    uiOutput("messages")
  ),
  # Prompt input and send button side by side, half the width each.
  layout_column_wrap(
    width = 1 / 2,
    textInput("prompt", label = NULL, width = "100%"),
    uiOutput("sendButton")
  )
)
server <- function(input, output, session) {
  # Turn a promise rejection reason into a single display string. Rejection
  # reasons are usually condition objects; as.character() on a condition
  # coerces each of its list fields (message and call) separately and returns
  # several strings, so conditionMessage() is used for conditions.
  failure_text <- function(x) {
    if (inherits(x, "condition")) {
      conditionMessage(x)
    } else {
      paste(as.character(x), collapse = " ")
    }
  }

  # Context for the observers that load the model in the background session.
  # It also handles reloads: `reload` counts retry attempts. It must start as a
  # number because the retry check below does `loading$reload < 10`, and a
  # comparison with NULL is zero-length, which makes `if ()` fail.
  loading <- reactiveValues(
    model = NULL,
    reload = 0
  )

  # TRUE while this session waits for a model load started by another
  # session; it drives the polling observer below.
  waiting_for_load <- reactiveVal(FALSE)

  # Observer used at app startup (and on every retry) to allow using the
  # 'Send' button once the model has been loaded.
  observeEvent(
    loading$reload,
    ignoreInit = FALSE,
    ignoreNULL = FALSE,
    priority = 0,
    {
      # The model is already loaded: propagate this by setting `generating`
      # to FALSE, which enables the send button.
      if (isTRUE(sess$is_loaded)) {
        context$generating <- FALSE
        return()
      }

      # `is_loaded` is FALSE: another session is loading the model. The
      # handler of an observeEvent() runs inside isolate(), so calling
      # invalidateLater() here would never re-run it; hand over to the
      # polling observer instead.
      if (!is.null(sess$is_loaded)) {
        waiting_for_load(TRUE)
        return()
      }

      # `is_loaded` is NULL: no task is loading the model, so start one.
      cat("Started loading model ....", "\n")
      loading$model <- sess$load_model(repo)
      sess$is_loaded <- FALSE # not yet loaded, but loading

      loading$model <- loading$model %>%
        promises::then(
          onFulfilled = function(x) {
            # NOTE(review): Sys.sleep() blocks the whole R process (all
            # sessions) for 5 seconds; confirm this delay is still needed.
            Sys.sleep(5)
            cat("Model has been loaded!", "\n")
            context$generating <- FALSE
            sess$is_loaded <- TRUE
            TRUE
          },
          onRejected = function(x) {
            context$generating <- "error"
            msg <- list(
              role = "error",
              content = paste0("Error loading the model:\n", failure_text(x))
            )
            context$messages <- append(context$messages, list(msg))
            # Set up for a retry: NULL marks "no load in progress", so the
            # next run of this observer starts a fresh load.
            sess$is_loaded <- NULL
            sess$sess <- NULL
            if (loading$reload < 10) {
              Sys.sleep(5) # back off before retrying
              loading$reload <- loading$reload + 1
            }
            FALSE
          }
        )
    }
  )

  # Polls every 100 ms while another session is loading the model. Unlike an
  # observeEvent() handler, a plain observe() is re-run by invalidateLater().
  observe({
    req(waiting_for_load())
    if (isTRUE(sess$is_loaded)) {
      # The other session finished loading: enable the send button here too.
      waiting_for_load(FALSE)
      context$generating <- FALSE
    } else if (is.null(sess$is_loaded)) {
      # The other session's load failed: trigger a load from this session.
      waiting_for_load(FALSE)
      loading$reload <- isolate(loading$reload) + 1
    } else {
      invalidateLater(100, session)
    }
  })

  # Context for generating messages.
  context <- reactiveValues(
    # "loading" until the model is ready, then TRUE while tokens are being
    # generated and FALSE when idle; "error" after a failure.
    generating = "loading",
    idxs = NULL, # the current sequence of token ids fed to the model
    n_tokens = 0, # number of tokens generated for the current reply
    messages = list() # chat history: list of list(role, content)
  )

  # Observer for the send action button; it triggers the rest of the
  # reactions by updating the messages and the token context.
  observeEvent(input$send, ignoreInit = TRUE, {
    # Only accept a prompt when the model is loaded and idle. The button is
    # rendered disabled in every other state; this guards against stray
    # events reaching the tokenizer before the model is available.
    if (!identical(context$generating, FALSE)) {
      return()
    }
    # If the prompt is empty, there's nothing to do.
    prompt_text <- input$prompt
    if (is.null(prompt_text) || prompt_text == "") {
      return()
    }

    # Enter generation mode.
    context$generating <- TRUE

    # Add the user message, then the (initially empty) assistant message that
    # the generation observer fills in token by token.
    context$messages <- append(
      context$messages,
      list(
        list(role = "user", content = prompt_text),
        list(role = "assistant", content = "")
      )
    )

    # On the first send, the token context starts with the system prompt.
    if (is.null(context$idxs)) {
      context$idxs <- sess$tok$encode(system_prompt)$ids
    }
    # Append the prompt wrapped in the special tokens used for generation.
    turn <- paste0("<|USER|>", prompt_text, "<|ASSISTANT|>")
    context$idxs <- c(context$idxs, sess$tok$encode(turn)$ids)
    cat("Tokens in context: ", length(context$idxs), "\n")
  })

  # Controls the state of the send button: disabled while loading, while
  # generating and after an error; enabled otherwise.
  observeEvent(context$generating, priority = 10, {
    state <- context$generating

    # On first startup the button is inserted directly with insertUI
    # (immediate = TRUE); per the original design, a rendered output would
    # only appear after all observers ran, including the one loading the
    # model.
    if (identical(state, "loading")) {
      insertUI(
        "#sendButton",
        ui = ui_send(list(
          class = "btn-secondary disabled",
          label = "Loading model ...",
          icon = icon("spinner", class = "fa-spin")
        )),
        immediate = TRUE
      )
      insertUI(
        "#messages",
        ui = ui_messages(list(list(
          role = "info",
          content = "Model is loading. It might take some time."
        ))),
        immediate = TRUE
      )
      return()
    }

    btn <- if (identical(state, "error")) {
      list(class = "btn-secondary disabled", label = "Generating error ...")
    } else if (isTRUE(state)) {
      list(
        class = "btn-secondary disabled",
        label = "Generating response ...",
        icon = icon("spinner", class = "fa-spin")
      )
    } else {
      list(class = "btn-primary", label = "Send")
    }
    output$sendButton <- renderUI({
      ui_send(btn)
    })
  })

  # Re-renders the message list whenever it changes.
  observeEvent(context$messages, priority = 10, {
    output$messages <- renderUI({
      ui_messages(context$messages)
    })
  })

  # Generates text with the model loaded in `sess`. Each run produces one
  # token from context$idxs. The observer appends that token to the last
  # message. If generation should continue, it extends context$idxs, which
  # re-triggers this observer.
  observeEvent(context$idxs, {
    context$idxs %>%
      sess$generate() %>%
      promises::then(
        onFulfilled = function(id) {
          # Special tokens that stop generation.
          if (id %in% c(50278L, 50279L, 50277L, 1L, 0L)) {
            context$generating <- FALSE
            context$n_tokens <- 0
            return()
          }

          # Update the last (assistant) message with the new token.
          messages <- context$messages
          last <- length(messages)
          messages[[last]]$content <- paste0(
            messages[[last]]$content,
            sess$tok$decode(id)
          )
          context$messages <- messages

          # Stop once `max_n_tokens` tokens were generated for this reply.
          context$n_tokens <- context$n_tokens + 1
          if (context$n_tokens >= max_n_tokens) {
            context$generating <- FALSE
            context$n_tokens <- 0
            return()
          }
          context$idxs <- c(context$idxs, id)
        },
        onRejected = function(x) {
          # Post the generation error in the message list with an error role.
          msg <- list(
            role = "error",
            content = paste0("Error generating token. ", failure_text(x))
          )
          context$messages <- append(context$messages, list(msg))
          # No longer generating; flag the error state.
          context$generating <- "error"
        }
      )
    # The promise is deliberately not returned (presumably so Shiny does not
    # treat this observer as asynchronous and wait on it; confirm).
    NULL
  })
}
# Render the chat history as a stack of small cards.
#
# messages: list of messages, each a list with `role` ("user", "assistant",
#   "info" or "error") and `content` (text to show after the role's emoji).
# Returns a single non-fillable card_body() containing one card per message,
# spaced with gap = 5.
#
# NOTE(review): the emoji literals below look mis-encoded (mojibake) in this
# copy of the file; confirm the intended characters against the original.
ui_messages <- function(messages) {
  role_emoji <- c(user = "π€", assistant = "π€", info = "π£", error = "π")

  # One card for a single message: "<emoji>: <content>" in a paragraph.
  message_card <- function(entry) {
    line <- paste0(role_emoji[entry$role], ": ", entry$content)
    card(
      style = "margin-bottom:5px;",
      card_body(p(line))
    )
  }

  cards <- lapply(messages, message_card)
  do.call(card_body, c(cards, list(gap = 5, fillable = FALSE)))
}
# Build the full-width "send" action button from a button spec.
#
# btn: list with `label`, `class` (CSS classes such as "btn-primary") and,
#   optionally, `icon`; when `icon` is absent, btn$icon is NULL (no icon).
# The input id is always "send", the id the server observes as input$send.
ui_send <- function(btn) {
  button_label <- btn$label
  button_class <- btn$class
  button_icon <- btn$icon
  actionButton(
    inputId = "send",
    label = button_label,
    icon = button_icon,
    width = "100%",
    class = button_class
  )
}
# Assemble the application from the page layout and the server function.
shinyApp(ui = ui, server = server)