# gptneox-chat / app.R
# dfalbel's picture
# some more organization
# c2274c7 unverified
# raw
# history blame
# 8.97 kB
library(shiny)
library(bslib)
library(minhub)
library(magrittr)
source("model-session.R")
# Hugging Face repository of the model to serve. Override it with the
# MODEL_REPO environment variable, e.g. MODEL_REPO="EleutherAI/pythia-70m"
# for a much smaller model during development.
repo <- Sys.getenv("MODEL_REPO", unset = "stabilityai/stablelm-tuned-alpha-3b")

# Model session object from model-session.R. It is created once at file level,
# so every Shiny session served by this R process shares the same object.
sess <- model_session$new()

# Maximum number of tokens generated for a single assistant reply.
# Override it with the MAX_N_TOKENS environment variable (default 100).
max_n_tokens <- suppressWarnings(
  as.integer(Sys.getenv("MAX_N_TOKENS", unset = "100"))
)
if (is.na(max_n_tokens) || max_n_tokens < 1L) {
  stop("MAX_N_TOKENS must be a positive integer.", call. = FALSE)
}

# System prompt, encoded once at the start of each conversation's token
# context (before the first <|USER|> turn).
system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"
# Page layout: a card holding the chat transcript (the `messages` output)
# above a two-column row with the prompt text input and the `sendButton`
# output. Both outputs are filled in by the server function.
ui <- page_fillable(
theme = bs_theme(bootswatch = "minty"),
# NOTE(review): no shinyjs function is called in this file -- confirm this is
# still needed.
shinyjs::useShinyjs(),
card(
height="90%",
# NOTE(review): `heights_equal`, `width` and `fillable` look like
# layout_column_wrap()/card_body() arguments; if the installed bslib's card()
# has no such parameters they are passed through `...` as plain HTML
# attributes on the card -- confirm the intended effect.
heights_equal = "row",
width = 1,
fillable = FALSE,
uiOutput("messages")
),
# bottom row: prompt input and send button side by side (half width each)
layout_column_wrap(
width = 1/2,
textInput("prompt", label = NULL, width="100%"),
uiOutput("sendButton")
)
)
server <- function(input, output, session) {
  # Turn a promise rejection reason into one display string. Rejections
  # normally carry a condition object; `as.character()`/`paste0()` on a
  # condition returns one string per field (message and call), which split a
  # single error into several chat paragraphs.
  format_error <- function(x) {
    if (inherits(x, "condition")) {
      conditionMessage(x)
    } else {
      paste(as.character(x), collapse = "\n")
    }
  }

  # context for generating messages
  context <- reactiveValues(
    # TRUE while a reply is being generated, FALSE when idle, "loading" until
    # the model is available and "error" after a load/generation failure
    generating = "loading",
    idxs = NULL, # the current sequence of tokens
    n_tokens = 0, # number of tokens already generated for the current reply
    messages = list()
  )

  # context for the observers that load the model. `sess` is created at file
  # level, so the model is shared by every session of this R process.
  loading <- reactiveValues(
    model = NULL, # promise chain of the current loading task
    # retry counter. It must start as a number: with NULL, `reload < 10` is
    # `logical(0)` and the `if` in the rejection handler errored, so no retry
    # ever happened.
    reload = 0L,
    waiting = FALSE, # TRUE while another session is loading the model
    recheck = 0L # bumped by the polling observer to re-run the loader
  )

  # Observer used at app startup time (and on retries) to load the model, or
  # to allow using the 'Send' button if the model has already been loaded.
  observeEvent(
    list(loading$reload, loading$recheck),
    ignoreInit = FALSE,
    ignoreNULL = FALSE,
    priority = 0,
    {
      # the model is already loaded, we just make sure that we propagate this
      # by setting generating to FALSE
      if (isTRUE(sess$is_loaded)) {
        context$generating <- FALSE
        return()
      }
      # the model is loading, but this is handled by another session. This
      # handler runs inside isolate(), so invalidateLater() here would never
      # re-run it; the polling observer below re-triggers it instead.
      if (!is.null(sess$is_loaded)) {
        loading$waiting <- TRUE
        return()
      }
      # the model isn't loaded and no task is trying to load it, so we start
      # a new task to load it
      cat("Started loading model ....", "\n")
      task <- sess$load_model(repo)
      sess$is_loaded <- FALSE # not yet loaded, but loading
      # kept as the handler's last expression, so the handler still returns
      # this promise chain as before
      loading$model <- task %>%
        promises::then(
          onFulfilled = function(x) {
            # NOTE(review): this blocks the whole R process (every session)
            # for 5 seconds; the reason is not visible here -- confirm.
            Sys.sleep(5)
            cat("Model has been loaded!", "\n")
            context$generating <- FALSE
            sess$is_loaded <- TRUE
            TRUE
          },
          onRejected = function(x) {
            context$generating <- "error"
            msg <- list(
              role = "error",
              content = paste0("Error loading the model:\n", format_error(x))
            )
            context$messages <- append(context$messages, list(msg))
            # setup for retry!
            sess$is_loaded <- NULL # means failure!
            sess$sess <- NULL
            if (loading$reload < 10) {
              Sys.sleep(5)
              loading$reload <- loading$reload + 1L
            }
            FALSE
          }
        )
    }
  )

  # Polls the shared model state while another session is loading it. A plain
  # observe() is not isolated, so invalidateLater() does re-run it. Once the
  # load has finished or failed, the loader above is re-triggered; it then
  # enables 'Send' or starts a new loading task itself.
  observe({
    req(loading$waiting)
    if (isFALSE(sess$is_loaded)) {
      invalidateLater(100, session)
      return()
    }
    loading$waiting <- FALSE
    loading$recheck <- isolate(loading$recheck) + 1L
  })

  observeEvent(input$send, ignoreInit = TRUE, {
    # observer for the send button; it triggers the rest of the reactions.
    # if the prompt is empty, there's nothing to do
    if (is.null(input$prompt) || input$prompt == "") {
      return()
    }
    # the user clicked 'send' and the prompt is not empty:
    # we will enter in generation mode
    context$generating <- TRUE
    # add the user message and the (initially empty) assistant reply
    context$messages <- append(
      context$messages,
      list(
        list(role = "user", content = input$prompt),
        list(role = "assistant", content = "")
      )
    )
    # on the first send the token context starts with the system prompt
    if (is.null(context$idxs)) {
      context$idxs <- sess$tok$encode(system_prompt)$ids
    }
    # append the prompt, wrapped in the special tokens used for generation
    prompt <- paste0("<|USER|>", input$prompt, "<|ASSISTANT|>")
    context$idxs <- c(context$idxs, sess$tok$encode(prompt)$ids)
    cat("Tokens in context: ", length(context$idxs), "\n")
  })

  observeEvent(context$generating, priority = 10, {
    # this controls the state of the send button: disabled while loading,
    # generating or after an error; enabled ("Send") otherwise.
    # in the first startup we have to insert the send button directly,
    # otherwise it would only appear after all observers ran, including the
    # one that loads the model
    if (identical(context$generating, "loading")) {
      btn <- list(
        class = "btn-secondary disabled",
        label = "Loading model ...",
        icon = icon("spinner", class = "fa-spin")
      )
      insertUI(
        "#sendButton",
        ui = ui_send(btn),
        immediate = TRUE
      )
      msgs <- list(list(
        role = "info",
        content = "Model is loading. It might take some time."
      ))
      insertUI(
        "#messages",
        ui = ui_messages(msgs),
        immediate = TRUE
      )
      return()
    }
    # now we can use the regular output for everyone
    btn <- if (identical(context$generating, "error")) {
      list(class = "btn-secondary disabled", label = "Generating error ...")
    } else if (isTRUE(context$generating)) {
      list(
        class = "btn-secondary disabled",
        label = "Generating response ...",
        icon = icon("spinner", class = "fa-spin")
      )
    } else {
      list(class = "btn-primary", label = "Send")
    }
    output$sendButton <- renderUI({
      ui_send(btn)
    })
  })

  observeEvent(context$messages, priority = 10, {
    # this observer generates and updates the messages list
    output$messages <- renderUI({
      ui_messages(context$messages)
    })
  })

  # ids of special tokens that end the assistant's turn. NOTE(review): these
  # are tokenizer-specific (presumably StableLM's chat-role, padding and
  # end-of-text tokens) -- confirm when serving another MODEL_REPO.
  stop_ids <- c(50278L, 50279L, 50277L, 1L, 0L)

  observeEvent(context$idxs, {
    # Generates the next token from `context$idxs` with the model in `sess`,
    # appends it to the last message and to `idxs` (which re-triggers this
    # observer for the following token).
    # Only generate while a reply is in progress: when the token budget runs
    # out, the last token is still recorded in `idxs` (so the context matches
    # the displayed text) after generation has been switched off.
    if (!isTRUE(context$generating)) {
      return(NULL)
    }
    context$idxs %>%
      sess$generate() %>%
      promises::then(
        onFulfilled = function(id) {
          if (id %in% stop_ids) {
            context$generating <- FALSE
            context$n_tokens <- 0
            return() # special tokens that stop generation.
          }
          # update last message with the newly generated token
          messages <- context$messages
          last <- length(messages)
          messages[[last]]$content <- paste0(
            messages[[last]]$content,
            sess$tok$decode(id)
          )
          context$messages <- messages
          # update the token counter; stop once exactly `max_n_tokens` tokens
          # have been shown (the previous `>` let one extra token through)
          context$n_tokens <- context$n_tokens + 1
          if (context$n_tokens >= max_n_tokens) {
            context$generating <- FALSE
            context$n_tokens <- 0
          }
          context$idxs <- c(context$idxs, id)
        },
        onRejected = function(x) {
          # post the generation error in the message list with an error role
          msg <- list(
            role = "error",
            content = paste0("Error generating token:\n", format_error(x))
          )
          context$messages <- append(context$messages, list(msg))
          # we are no longer generating: flag the error state
          context$generating <- "error"
        }
      )
    # return NULL rather than the promise -- presumably so Shiny does not hold
    # this session's output updates until the token is generated (confirm)
    NULL
  })
}
# Plain-text rendering of one chat message as "<emoji>: <content>".
# `msg` is a list with `role` and `content`. Emojis are written as Unicode
# escapes so they survive any source-file encoding (the literal glyphs had
# been mangled into mojibake). Unknown roles fall back to the role name
# instead of "NA", and multi-string content is joined with newlines.
format_message_text <- function(msg) {
  emojis <- c(
    user = "\U{1F917}", # hugging face
    assistant = "\U{1F916}", # robot
    info = "\U{1F4E3}", # megaphone
    error = "\U{1F62D}" # loudly crying face
  )
  role <- msg$role
  prefix <- if (length(role) == 1L && role %in% names(emojis)) {
    emojis[[role]]
  } else {
    as.character(role)
  }
  paste0(prefix, ": ", paste(msg$content, collapse = "\n"))
}

# Builds the chat transcript: one small card per message, stacked in a
# non-fillable card body with a 5px gap.
# `messages` is a list of lists, each with `role` and `content`.
ui_messages <- function(messages) {
  msg_cards <- messages %>%
    lapply(function(msg) {
      card(
        style = "margin-bottom:5px;",
        card_body(p(format_message_text(msg)))
      )
    })
  rlang::exec(card_body, !!!msg_cards, gap = 5, fillable = FALSE)
}
# Full-width action button with input id "send".
# `btn` describes the button's state: `label`, `class` (CSS classes) and an
# optional `icon` (NULL for no icon).
ui_send <- function(btn) {
  actionButton(
    inputId = "send",
    label = btn$label,
    icon = btn$icon,
    class = btn$class,
    width = "100%"
  )
}
# Assemble the app object from the page layout and the server function.
shinyApp(ui = ui, server = server)