# tl-playground / app.py
# Author: Andreas Sünder
# Commit fce98ea: "Add files from previous repo"
from huggingface_hub import list_models
import streamlit as st
from model import ReplicateModel
import os
import pandas as pd
# Directory holding the topic-set CSV files. Anchored to this script's own
# directory (instead of the bare relative path 'datasets') so the files are
# found no matter which working directory the app is launched from.
DATASETS_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'datasets')

# Models offered in the "Select a model" dropdown, keyed by display name.
models = {
    'mistral': ReplicateModel('mistralai/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70'),
}

# Prompt templates keyed by name. The literal placeholder "[KEYWORDS]" is
# replaced with the comma-joined keywords of the selected row before the
# prompt is sent to the model.
prompts = {
    'simple_prompt':
    '''
I have topic that is described by the following keywords: [KEYWORDS]
Based on the information above, extract a short topic label in the following format:
topic: <topic label>
''',
    # Uncomment to enable the free-text prompt editor branch of the UI.
    # 'custom_prompt': ''
}

# Topic sets offered in the "Select a dataset" dropdown: name -> CSV path.
topicsets = {
    'lda_poe_topics': os.path.join(DATASETS_PATH, 'lda_poe_topics.csv'),
}
@st.cache_data(show_spinner=False)
def get_available_models():
    """Return the names of the models registered in ``models``.

    Returns a plain ``list`` rather than ``models.keys()``: ``st.cache_data``
    pickles return values, and a ``dict_keys`` view cannot be pickled, so the
    previous return value made the cached call fail.
    """
    # Alternative source (currently unused): models published on the Hugging
    # Face Hub, e.g. huggingface_hub.list_models(author='textminr').
    return list(models)
@st.cache_resource(show_spinner='Loading model...')
def load_model(model_name: str):
    """Load and return the model registered under *model_name* in ``models``.

    Wrapped in ``st.cache_resource``, so ``.load()`` runs once per distinct
    *model_name* and the resulting object is reused on later calls.

    Raises:
        KeyError: if *model_name* is not a key of ``models``.
    """
    registered = models[model_name]
    return registered.load()
# Page-level configuration, kept as the first Streamlit command on the page.
st.set_page_config(
    page_title='TL playground',
    page_icon='🚀',
    layout='wide',
)
st.title('🚀 Topic Labelling playground')

# On screens at least 1500px wide, cap the main content container's width at
# this percentage of its parent. The selectors target Streamlit's internal
# CSS class names — NOTE(review): these may change between Streamlit versions.
percentage_width_main = 70
_wide_container_css = f'''<style>
@media only screen and (min-width: 1500px) {{
.appview-container .main .block-container{{
max-width: {percentage_width_main}%;
}}
}}
</style>
'''
st.markdown(_wide_container_css, unsafe_allow_html=True)
col1, col2 = st.columns(2, gap='medium')

# --- Left column: model and dataset selection --------------------------------
sel_model_name = col1.selectbox('Select a model', models, index=None, placeholder='Select a model')
if sel_model_name:
    # Cached per model name by load_model's @st.cache_resource.
    model = load_model(sel_model_name)

sel_dataset_name = col1.selectbox('Select a dataset', topicsets.keys(), index=None)
# Defined up front so the Generate button's `disabled` check never hits an
# unbound name when no dataset is selected.
sel_dataset = None
sel_row_index = None
if sel_dataset_name:
    # header=None: every CSV line is a data row; no column names are read.
    sel_dataset = pd.read_csv(topicsets[sel_dataset_name], header=None)
    col1.dataframe(sel_dataset)
    # Returns an index label (None if the dataset has no rows).
    sel_row_index = col1.selectbox('Select a row', sel_dataset.index)

# --- Right column: prompt selection and generation ---------------------------
sel_prompt = col2.selectbox('Select a prompt', prompts.keys())
if sel_prompt != 'custom_prompt':
    col2.code(prompts[sel_prompt], language='text')
    sel_prompt_text = prompts[sel_prompt]
else:
    # Rendered in col2 alongside the other prompt controls (st.text_area would
    # place the editor outside the two-column layout).
    sel_prompt_text = col2.text_area('Custom prompt', height=200)
    col2.caption('Make sure to use "[KEYWORDS]" to indicate where the keywords should be inserted.')

btn_generate = col2.button(
    'Generate',
    disabled=(
        sel_model_name is None
        or sel_dataset_name is None
        or sel_row_index is None
        or not sel_prompt_text
    ),
)

if btn_generate:
    # Skip the first cell of the row (presumably a topic id — TODO confirm
    # against the CSV); the remaining cells are the keywords. Empty cells
    # (NaN for rows with fewer keywords) are dropped and values are
    # stringified so str.join cannot fail on non-str cells.
    row = sel_dataset.loc[sel_row_index]
    keywords = ','.join(row.iloc[1:].dropna().astype(str).tolist())
    placeholder = col2.empty()
    with placeholder, st.spinner('Generating...'):
        prompt = sel_prompt_text.replace('[KEYWORDS]', keywords)
        result = model.generate(prompt)
    message = col2.chat_message("ai")
    message.write(result)
    message.caption('Keywords: ' + keywords)