Spaces:
Runtime error
Runtime error
from huggingface_hub import list_models | |
import streamlit as st | |
from model import ReplicateModel | |
import os | |
import pandas as pd | |
# Directory (relative to the working directory) holding the topic CSV files.
DATASETS_PATH = 'datasets'

# Registry of selectable models: UI name -> ReplicateModel wrapper.
# The argument is a Replicate "<owner>/<model>:<version-hash>" reference.
models = dict(
    mistral=ReplicateModel(
        'mistralai/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70'
    ),
)
# Prompt templates selectable in the UI, keyed by their display name.
# "[KEYWORDS]" is replaced with the comma-joined keywords of the chosen row
# before the prompt is sent to the model.
# NOTE(review): "I have topic" is missing an article; left as-is because
# editing the template changes model output.
prompts = dict(
    simple_prompt='''
I have topic that is described by the following keywords: [KEYWORDS]
Based on the information above, extract a short topic label in the following format:
topic: <topic label>
''',
    # custom_prompt='',  # disabled: re-enable to let users type their own template
)
# Selectable topic datasets: UI name -> CSV path under DATASETS_PATH.
topicsets = dict(
    lda_poe_topics=os.path.join(DATASETS_PATH, 'lda_poe_topics.csv'),
)
def get_available_models():
    """Return the names of the models registered in ``models``.

    The result is a ``dict_keys`` view of ``models``, so it reflects any
    later changes to that dict.
    """
    registered = models.keys()
    return registered
def load_model(model_name: str):
    """Return the result of calling ``load()`` on the registered model *model_name*.

    Raises:
        KeyError: if *model_name* is not a key of ``models``.
    """
    # NOTE(review): Streamlit re-runs the script on every interaction, so this
    # is called on each rerun; consider st.cache_resource if load() is costly.
    entry = models[model_name]
    return entry.load()
# Page chrome. set_page_config must be the first Streamlit call of the script.
st.set_page_config(page_title='TL playground', page_icon='🚀', layout='wide')
st.title('🚀 Topic Labelling playground')

# On viewports at least 1500px wide, cap the main container's width at this
# percentage. The selector targets Streamlit's own container classes
# (presumably internal; verify after Streamlit upgrades).
percentage_width_main = 70
_wide_screen_css = f'''<style>
@media only screen and (min-width: 1500px) {{
.appview-container .main .block-container{{
max-width: {percentage_width_main}%;
}}
}}
</style>
'''
st.markdown(_wide_screen_css, unsafe_allow_html=True)
# Two panes: inputs (model, dataset, row) on the left; prompt and output on the right.
panes = st.columns(2, gap='medium')
col1, col2 = panes

with col1:
    sel_model_name = st.selectbox('Select a model', models, index=None, placeholder='Select a model')
# index=None means nothing is selected until the user picks a model.
if sel_model_name:
    model = load_model(sel_model_name)

with col1:
    sel_dataset_name = st.selectbox('Select a dataset', topicsets.keys(), index=None)
    if sel_dataset_name:
        # The CSV has no header row, so columns are numbered 0..N-1.
        sel_dataset = pd.read_csv(topicsets[sel_dataset_name], header=None)
        st.dataframe(sel_dataset)
        sel_row_index = st.selectbox('Select a row', sel_dataset.index)
# Choose a prompt template. Predefined templates are shown read-only; the
# (currently disabled) 'custom_prompt' entry lets the user type their own.
sel_prompt = col2.selectbox('Select a prompt', prompts.keys())
if sel_prompt != 'custom_prompt':
    col2.code(prompts[sel_prompt], language='text')
    sel_prompt_text = prompts[sel_prompt]
else:
    # Render into col2 (not the page body) so the editor stays in the prompt pane.
    sel_prompt_text = col2.text_area('Custom prompt', height=200)
    col2.caption('Make sure to use "[KEYWORDS]" to indicate where the keywords should be inserted.')
# Generation is only possible once both a model and a dataset are selected;
# model, sel_dataset and sel_row_index exist only after those selections.
btn_generate = col2.button('Generate', disabled=(sel_model_name is None or sel_dataset_name is None))
if btn_generate:
    # sel_row_index is an index *label* (from sel_dataset.index), so use .loc.
    # Column 0 is skipped; presumably it holds a topic id rather than a keyword — TODO confirm.
    keyword_cells = sel_dataset.loc[sel_row_index].tolist()[1:]
    # Rows with fewer keywords than the widest row are NaN-padded by read_csv,
    # and numeric cells are not str; both would make str.join raise TypeError.
    keywords = ','.join(str(cell) for cell in keyword_cells if pd.notna(cell))
    placeholder = col2.empty()
    with placeholder, st.spinner('Generating...'):
        prompt = sel_prompt_text.replace('[KEYWORDS]', keywords)
        result = model.generate(prompt)
    message = col2.chat_message("ai")
    message.write(result)
    message.caption('Keywords: ' + keywords)