# Author: KennethTM
# Last change: "Changed default generation settings" (commit 6ee4621)
import streamlit as st
import galai as gal
#https://github.com/paperswithcode/galai/blob/main/notebooks/Introduction%20to%20Galactica%20Models.ipynb
@st.cache_resource
def load_model(model_name):
    """Load a GALACTICA model by its size name.

    Memoized with ``st.cache_resource`` so Streamlit reruns reuse the already
    loaded model instead of loading it again for every interaction.

    Args:
        model_name: size identifier forwarded to ``galai.load_model``; the
            sidebar offers "mini" and "base".

    Returns:
        Whatever ``galai.load_model`` returns; the app later calls its
        ``generate`` and ``generate_reference`` methods.
    """
    # num_gpus=0: no GPUs are requested. The original code carried a
    # commented-out `dtype=torch.float16` option, which is not enabled here.
    return gal.load_model(model_name, num_gpus=0)
# Backing store for the editor's contents: seeded with an empty string only
# when the session does not have it yet, so text produced by the callbacks
# survives reruns and is fed back into the text area as its value.
if "text" not in st.session_state:
    st.session_state.text = ""
def generate_text():
    """Run the model on the editor's contents and store the output.

    Registered as the text area's ``on_change`` callback. The prompt is the
    current value of ``st.session_state.editor``; the generation settings come
    from the module-level sidebar values (``penalty_alpha``, ``top_k``,
    ``max_new_tokens``, ``new_doc``). Whatever ``model.generate`` returns is
    written to ``st.session_state['text']``, which the text area uses as its
    value on the next rerun.
    """
    prompt = st.session_state.editor
    generation_kwargs = dict(
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
        new_doc=new_doc,
    )
    st.session_state['text'] = model.generate(input_text=prompt, **generation_kwargs)
def suggest_reference():
    """Append a model-suggested citation to the editor's text.

    Registered as the "Add citation" button's ``on_click`` callback. The
    current editor contents are kept, the output of
    ``model.generate_reference`` for those contents is appended after a single
    space, and the combined string replaces ``st.session_state['text']``.
    """
    current_text = st.session_state.editor
    citation = model.generate_reference(input_text=current_text)
    st.session_state['text'] = " ".join((current_text, citation))
# Sidebar: model selection and generation settings.
# NOTE: the callbacks read `model`, `max_new_tokens`, `penalty_alpha`, `top_k`
# and `new_doc` as module-level names, so these names must stay as they are.
st.sidebar.markdown("### Select model")
choose_model = st.sidebar.selectbox(label="Size", options=["mini", "base"])
model = load_model(choose_model)

st.sidebar.markdown("### Text generation settings")
max_new_tokens = st.sidebar.slider(
    "Max new tokens",
    min_value=10,
    max_value=100,
    value=10,
    step=10,
)
penalty_alpha = st.sidebar.slider(
    "Alpha penalty",
    min_value=0.0,
    max_value=2.0,
    value=0.6,
    step=0.1,
)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=10, value=2)
new_doc = st.sidebar.checkbox("New document", value=True)
# Main page: introduction text, the citation button and the editor.
# Markdown paragraphs are separated by blank lines; without them Markdown
# joins consecutive lines into a single paragraph.
st.markdown(
    '''
# Scientific writing assistant

## GALACTICA model

The [GALACTICA models](https://www.galactica.org) have been trained on a large corpus of scientific data (see also the [GitHub repository](https://github.com/paperswithcode/galai)). Try out the two smaller models here and see how they can be used to generate scientific text and suggest references.

Write text in the editor and press **TAB** or **CTRL+ENTER** to generate text.

Settings for model size (mini = 125 M and base = 1.3 B parameters) and text generation can be changed in the sidebar on the left.

## Suggest citations

Use the **Add citation** button to suggest and insert a citation into the text editor. The citation format is *Title, First author*.
''')

# The button appends a suggested citation via `suggest_reference`.
add_ref = st.button("Add citation", on_click=suggest_reference)

# The editor shows st.session_state['text'] (updated by the callbacks); any
# committed edit triggers `generate_text`. The label is hidden, so its text is
# never displayed.
text_editor = st.text_area(
    "(not shown)",
    st.session_state['text'],
    height=500,
    key="editor",
    on_change=generate_text,
    label_visibility="collapsed",
)