# Spaces:
# Runtime error
# Runtime error
library(shiny) | |
library(bslib) | |
library(minhub) | |
library(magrittr) | |
source("model-session.R") | |
# Model to serve (a Hugging Face repository id); override with MODEL_REPO.
repo <- Sys.getenv("MODEL_REPO", unset = "EleutherAI/pythia-70m")

# One model session for the whole R process, shared by every connected Shiny
# session (model_session is not defined in this file; presumably it comes from
# model-session.R, which is sourced above).
sess <- model_session$new()

# Maximum number of tokens generated per assistant reply; override with
# MAX_N_TOKENS. Missing, non-numeric or non-positive values fall back to 100.
max_n_tokens <- suppressWarnings(
  as.integer(Sys.getenv("MAX_N_TOKENS", unset = "100"))
)
if (is.na(max_n_tokens) || max_n_tokens < 1L) {
  max_n_tokens <- 100L
}

# Conversation preamble in the StableLM Tuned (Alpha) chat format. The server
# appends "<|USER|>...<|ASSISTANT|>" turns and generated tokens after it.
system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"
# Page layout: a card taking 90% of the page height whose body (`#messages`)
# holds the chat transcript written by insert_message(), followed by a
# two-column row with the prompt box and the Send button.
ui <- page_fillable(
  theme = bs_theme(bootswatch = "minty"),
  shinyjs::useShinyjs(),
  card(
    height = "90%",
    fillable = FALSE,
    card_body(id = "messages", gap = 5, fillable = FALSE)
  ),
  layout_column_wrap(
    width = 1/2,
    textInput("prompt", label = NULL, width = "100%"),
    actionButton("send", "Send", width = "100%")
  )
)
# Turn a promise rejection reason into a single display string.
#
# Rejection reasons are normally condition objects. as.character() on a
# condition returns one string per field (message and call), which would be
# shown as several chunks, so the condition message is used instead.
format_rejection <- function(reason) {
  if (inherits(reason, "condition")) {
    return(conditionMessage(reason))
  }
  paste(as.character(reason), collapse = " ")
}

# Per-session server logic.
#
# Chat flow: clicking Send appends the user turn to the `prompt` reactiveVal.
# Every change of `prompt` requests one token from `sess$generate()`
# (`next_token`); the resolved token is displayed and appended to `prompt`,
# which requests the following token. Generation stops on an empty token or
# after `max_n_tokens` tokens, and Send is re-enabled.
#
# Model loading: `sess` is shared by all sessions. `sess$is_loaded` is NULL
# when no model is loaded and no load is running (including after a failed
# load), FALSE while a load is in progress, and TRUE once the model is loaded.
server <- function(input, output, session) {
  prompt <- reactiveVal(value = system_prompt)
  n_tokens <- reactiveVal(value = 0)
  msg_id <- reactiveVal(value = 0)

  observeEvent(input$send, {
    if (is.null(input$prompt) || input$prompt == "") {
      return()
    }
    shinyjs::disable("send")
    updateActionButton(inputId = "send", label = "Waiting for model...")
    insert_message(msg_id, as.character(glue::glue("π€: {input$prompt}")))
    # Appending the user turn changes `prompt`, which triggers `next_token`.
    prompt(paste0(prompt(), "<|USER|>", input$prompt, "<|ASSISTANT|>"))
  })

  # Promise for the next token of the current prompt; re-evaluated whenever
  # `prompt` changes (not for its initial value). A failed generation is
  # reported in the chat and resolves to NULL.
  next_token <- eventReactive(prompt(), ignoreInit = TRUE, {
    prompt() %>%
      sess$generate() %>%
      promises::then(
        onRejected = function(x) {
          insert_message(
            msg_id,
            paste0("π Error generating token. ", format_rejection(x))
          )
          updateActionButton(
            inputId = "send",
            label = "Failing generation. Contact admin."
          )
          NULL
        }
      )
  })

  # Shows each generated token and requests the next one until an empty
  # token or `max_n_tokens` tokens have been produced.
  observeEvent(next_token(), {
    tok <- next_token()
    n_tokens(n_tokens() + 1)
    tok %>% promises::then(function(tok) {
      if (is.null(tok)) {
        # Generation failed and was already reported; Send stays disabled.
        # Stop here: `NULL != ""` is logical(0), which would make the
        # condition below an error.
        n_tokens(0)
        return(invisible(NULL))
      }
      if (n_tokens() == 1) {
        insert_message(msg_id, paste0("π€: ", tok), append = FALSE)
      } else {
        insert_message(msg_id, tok, append = TRUE)
      }
      if (tok != "" && n_tokens() < max_n_tokens) {
        prompt(paste0(prompt(), tok))
      } else {
        shinyjs::enable("send")
        updateActionButton(inputId = "send", label = "Send")
        n_tokens(0)
      }
    })
  })

  model_loaded <- reactiveVal()
  event_reload <- reactiveVal(value = 0)
  # TRUE while this session waits for a load started by another session.
  waiting_for_model <- reactiveVal(value = FALSE)

  # Runs at startup and on every `event_reload` bump: enables Send when the
  # model is ready, otherwise starts loading it or waits for the running load.
  observeEvent(event_reload(), ignoreNULL = FALSE, {
    if (isTRUE(sess$is_loaded)) {
      shinyjs::enable("send")
      updateActionButton(inputId = "send", label = "Send")
      return()
    }
    shinyjs::disable("send")
    updateActionButton(inputId = "send", label = "Loading the model...")
    if (!is.null(sess$is_loaded)) {
      # Another session is loading the model; the polling observer below
      # re-triggers this observer once that load settles.
      waiting_for_model(TRUE)
      return()
    }
    cat("Started loading model ....", "\n")
    model_loaded(sess$load_model(repo))
    sess$is_loaded <- FALSE # not yet loaded, but loading
    m <- model_loaded() %>%
      promises::then(onFulfilled = function(x) {
        cat("Model has been loaded!", "\n")
        shinyjs::enable("send")
        updateActionButton(inputId = "send", label = "Send")
        sess$is_loaded <- TRUE
        TRUE
      }, onRejected = function(x) {
        shinyjs::disable("send")
        insert_message(
          msg_id,
          paste0("π Error loading the model:\n", format_rejection(x))
        )
        sess$is_loaded <- NULL # means failure!
        sess$sess <- NULL
        # Retry while this session's reload counter is below 10.
        if (event_reload() < 10) {
          # NOTE(review): Sys.sleep() blocks the whole R process, so every
          # connected session is frozen for these 5 seconds.
          Sys.sleep(5)
          event_reload(event_reload() + 1)
        }
        FALSE
      })
    model_loaded(m)
  })

  # Polls every 5 seconds while another session's load is in progress, then
  # re-runs the loading observer. This must be a plain observe(): an
  # observeEvent() handler is evaluated inside isolate(), where
  # invalidateLater() never re-runs the observer.
  observe({
    req(waiting_for_model())
    if (identical(sess$is_loaded, FALSE)) {
      invalidateLater(5000, session)
      return()
    }
    waiting_for_model(FALSE)
    event_reload(isolate(event_reload()) + 1)
  })
}
# Escape the characters that are significant in HTML text content, so that
# arbitrary text can be appended to an element's HTML and display literally.
escape_message_html <- function(x) {
  x <- gsub("&", "&amp;", x, fixed = TRUE)
  x <- gsub("<", "&lt;", x, fixed = TRUE)
  gsub(">", "&gt;", x, fixed = TRUE)
}

# Add `msg` to the chat transcript (the `#messages` card body) and scroll the
# transcript to the bottom.
#
# message_id: reactiveVal holding the id number of the latest message.
# msg: text to show; it is always displayed as plain text.
# append: FALSE creates a new message card with the next id (and advances
#   `message_id`); TRUE appends `msg` to the latest message, which is how
#   streamed tokens are added.
# Returns the id number of the message that was written to.
insert_message <- function(message_id, msg, append = FALSE) {
  if (!append) {
    id <- message_id() + 1
    message_id(id)
    insertUI(
      "#messages",
      "beforeEnd",
      immediate = TRUE,
      ui = card(style = "margin-bottom:5px;", card_body(
        p(id = paste0("msg-", id), msg)
      ))
    )
  } else {
    id <- message_id()
    # shinyjs::html() passes the text as a message argument rather than as
    # JavaScript source. Splicing model output into a JS string literal broke
    # on quotes, backslashes and newlines, and allowed script injection.
    shinyjs::html(
      id = paste0("msg-", id),
      html = escape_message_html(msg),
      add = TRUE
    )
  }
  # scroll to bottom
  shinyjs::runjs("var elem = document.getElementById('messages'); elem.scrollTop = elem.scrollHeight;")
  id
}
# Build the Shiny application from the page layout and the server function.
shinyApp(ui = ui, server = server)