library(shiny)
library(bslib)
library(minhub)
library(magrittr)

source("model-session.R")

repo <- "EleutherAI/pythia-70m"
# repo <- "stabilityai/stablelm-tuned-alpha-3b"
repo <- Sys.getenv("MODEL_REPO", unset = repo)

sess <- model_session$new()
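
# NOTE: model-session.R is not shown in this file. From the way `sess` is used
# below, it is assumed to expose roughly this interface (names inferred from
# the calls in this app, not from the actual implementation):
#   sess$load_model(repo)  - returns a promise that resolves once the model is loaded
#   sess$generate(idxs)    - returns a promise that resolves to the next token id
#   sess$tok$encode(text)  - tokenizer call; `$ids` holds the integer token ids
#   sess$tok$decode(id)    - turns a token id back into text
#   sess$is_loaded         - NULL (not started), FALSE (loading) or TRUE (loaded)
#   sess$sess              - the underlying background session (reset on failure)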

max_n_tokens <- 100

system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"

ui <- page_fillable(
  theme = bs_theme(bootswatch = "minty"),
  card(
    height = "90%",
    heights_equal = "row",
    width = 1,
    fillable = FALSE,
    uiOutput("messages")
  ),
  layout_column_wrap(
    width = 1/2,
    textInput("prompt", label = NULL, width = "100%"),
    uiOutput("sendButton")
  )
)

server <- function(input, output, session) {
  # reactive values used by the observers that load the model in a background
  # session; `reload` is also used to trigger retries
  loading <- reactiveValues(
    model = NULL,
    reload = 0
  )

  # Observer used at app startup time to allow using the 'Send' button once the
  # model has been loaded.
  observeEvent(loading$reload, ignoreInit = FALSE, ignoreNULL = FALSE, priority = 0, {
    # the model is already loaded; we just propagate this by setting
    # `generating` to FALSE
    if (!is.null(sess$is_loaded) && sess$is_loaded) {
      context$generating <- FALSE
      return()
    }

    # the model isn't loaded and no task is trying to load it, so we start a
    # new task to load it
    if (is.null(sess$is_loaded)) {
      cat("Started loading model ....", "\n")
      loading$model <- sess$load_model(repo)
      sess$is_loaded <- FALSE # not yet loaded, but loading
    } else {
      # the model is loading, but this is handled by another session. We should
      # come back to this observer later to enable the send button once the
      # model is loaded.
      invalidateLater(100, session)
      return()
    }

    # this runs only for the cases where sess$is_loaded was NULL,
    # i.e. there was no model currently loading
    m <- loading$model %>%
      promises::then(
        onFulfilled = function(x) {
          Sys.sleep(5)
          cat("Model has been loaded!", "\n")
          context$generating <- FALSE
          sess$is_loaded <- TRUE
          TRUE
        },
        onRejected = function(x) {
          context$generating <- "error"
          msg <- list(
            role = "error",
            content = paste0("Error loading the model:\n", as.character(x))
          )
          context$messages <- append(context$messages, list(msg))

          # set up for a retry!
          sess$is_loaded <- NULL # means failure!
          sess$sess <- NULL

          if (loading$reload < 10) {
            Sys.sleep(5)
            loading$reload <- loading$reload + 1
          }
          FALSE
        })
    loading$model <- m
  })

  # context for generating messages
  context <- reactiveValues(
    generating = "loading", # a flag indicating if we are still generating tokens
    idxs = NULL,            # the current sequence of tokens
    n_tokens = 0,           # number of tokens already generated
    messages = list()
  )

  observeEvent(input$send, ignoreInit = TRUE, {
    # this is the observer for the 'Send' action button; it triggers the rest
    # of the reactions.

    # if the prompt is empty, there's nothing to do
    if (is.null(input$prompt) || input$prompt == "") {
      return()
    }

    # the user clicked 'Send' and the prompt is not empty:
    # we enter generation mode
    context$generating <- TRUE

    # we add the user message to the messages list ...
    context$messages <- append(
      context$messages,
      list(list(role = "user", content = input$prompt))
    )

    # ... and the start of the assistant message
    context$messages <- append(
      context$messages,
      list(list(role = "assistant", content = ""))
    )

    # we also update the idxs context value with the newly added prompt.
    # if this is the first 'Send' call, we also need to add the system prompt
    if (is.null(context$idxs)) {
      context$idxs <- sess$tok$encode(system_prompt)$ids
    }

    # we now append the prompt. the prompt is wrapped in the special tokens
    # used for generation:
    prompt <- paste0("<|USER|>", input$prompt, "<|ASSISTANT|>")
    context$idxs <- c(context$idxs, sess$tok$encode(prompt)$ids)
    cat("Tokens in context: ", length(context$idxs), "\n")
  })

  observeEvent(context$generating, priority = 10, {
    # this controls the state of the send button:
    # - if generating is TRUE we want it disabled, otherwise it's enabled
    # - if generating is "loading", the model is not yet loaded

    # on first startup we have to insert the send button directly with
    # insertUI(), because renderUI() output would only appear after all
    # observers had run, including the one that loads the model
    if (context$generating == "loading") {
      btn <- list(
        class = "btn-secondary disabled",
        label = "Loading model ...",
        icon = icon("spinner", class = "fa-spin")
      )
      insertUI(
        "#sendButton",
        ui = ui_send(btn),
        immediate = TRUE
      )
      msgs <- list(list(
        role = "info",
        content = "Model is loading. It might take some time."
      ))
      insertUI(
        "#messages",
        ui = ui_messages(msgs),
        immediate = TRUE
      )
      return()
    }

    # for all other states we can rely on renderUI()
    btn <- if (context$generating == "error") {
      list(class = "btn-secondary disabled", label = "Generating error ...")
    } else if (context$generating) {
      list(
        class = "btn-secondary disabled",
        label = "Generating response ...",
        icon = icon("spinner", class = "fa-spin")
      )
    } else {
      list(class = "btn-primary", label = "Send")
    }

    output$sendButton <- renderUI({
      actionButton("send", width = "100%", label = btn$label, class = btn$class,
                   icon = btn$icon)
    })
  })

  observeEvent(context$messages, priority = 10, {
    # this observer re-renders the messages list whenever it changes
    output$messages <- renderUI({
      ui_messages(context$messages)
    })
  })
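
  # Note on the generation loop below: changing context$idxs triggers the next
  # observer, which asks the model for a single token. When the promise
  # resolves, the token is decoded and appended to the last message, and its id
  # is appended to context$idxs, which triggers the observer again. The loop
  # stops when the model emits a stop token or max_n_tokens is reached.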

  observeEvent(context$idxs, {
    # this observer is responsible for actually generating text by calling the
    # model loaded in `sess`. it uses context$idxs to generate a new token and
    # updates it once done. It also appends the newly generated token to the
    # last message.
    context$idxs %>%
      sess$generate() %>%
      promises::then(
        onFulfilled = function(id) {
          if (id %in% c(50278L, 50279L, 50277L, 1L, 0L)) {
            context$generating <- FALSE
            context$n_tokens <- 0
            return() # special tokens that stop generation
          }

          # update the last message with the newly generated token
          messages <- context$messages
          new_msg <- paste0(
            messages[[length(messages)]]$content,
            sess$tok$decode(id)
          )
          messages[[length(messages)]]$content <- new_msg
          context$messages <- messages

          # update the token counter
          context$n_tokens <- context$n_tokens + 1
          if (context$n_tokens > max_n_tokens) {
            context$generating <- FALSE
            context$n_tokens <- 0
            return() # we already generated enough tokens
          }

          context$idxs <- c(context$idxs, id)
        },
        onRejected = function(x) {
          # if there was a generation error, we post a message in the message
          # list with an 'error' role
          msg <- list(
            role = "error",
            content = paste0("Error generating token: ", as.character(x))
          )
          context$messages <- append(context$messages, list(msg))
          # we also flag that we are no longer generating, by setting another
          # value for the `generating` "flag"
          context$generating <- "error"
        }
      )
    # return NULL instead of the promise so the reactive flush isn't held
    # until generation finishes
    NULL
  })
}

ui_messages <- function(messages) {
  emojis <- c(user = "🤗", assistant = "🤖", info = "📣", error = "😭")
  msg_cards <- messages %>% lapply(function(msg) {
    card(
      style = "margin-bottom:5px;",
      card_body(
        p(paste0(emojis[msg$role], ": ", msg$content))
      )
    )
  })
  rlang::exec(card_body, !!!msg_cards, gap = 5, fillable = FALSE)
}

ui_send <- function(btn) {
  actionButton(
    "send",
    icon = btn$icon,
    width = "100%",
    label = btn$label,
    class = btn$class
  )
}

shinyApp(ui, server)
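
# For reference: a hypothetical way to run this app against a different model,
# using the MODEL_REPO environment variable read at the top of the file
# (assumes this file is saved as app.R in the working directory):
#
#   Sys.setenv(MODEL_REPO = "stabilityai/stablelm-tuned-alpha-3b")
#   shiny::runApp(".")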