File size: 2,871 Bytes
fce98ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from huggingface_hub import list_models
import streamlit as st
from model import ReplicateModel

import os
import pandas as pd

# Directory holding the topic CSV files. It is a relative path, so it is
# resolved against the current working directory when the CSV is read.
DATASETS_PATH = 'datasets'

# Selectable models: display name -> model wrapper. The wrapper is loaded via
# `.load()` and queried via `.generate(prompt)` further down in this script.
# NOTE(review): the suffix after ':' looks like a pinned Replicate model
# version hash — confirm.
models = {
  'mistral': ReplicateModel('mistralai/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70'),
}

# Prompt templates: name -> template text. The literal placeholder "[KEYWORDS]"
# is replaced with the comma-joined keywords of the selected dataset row. The
# rest of the text, including the leading newline and indentation of the
# triple-quoted literal, is passed to the model unchanged.
# Adding a 'custom_prompt' key enables the free-text prompt editor in the UI.
prompts = {
  'simple_prompt':
    '''
    I have topic that is described by the following keywords: [KEYWORDS]

    Based on the information above, extract a short topic label in the following format:
    topic: <topic label>
    '''
  # 'custom_prompt': ''
}

# Topic datasets: display name -> CSV path. The CSVs are read with
# header=None, i.e. they are treated as having no header row.
topicsets = {
  'lda_poe_topics': os.path.join(DATASETS_PATH, 'lda_poe_topics.csv'),
}

@st.cache_data(show_spinner=False)
def get_available_models():
  """Return the names of the models registered in ``models``.

  Returns:
    A list of model names, in the dict's insertion order.

  The result is a ``list`` rather than ``models.keys()``: ``st.cache_data``
  pickles cached return values, and a ``dict_keys`` view cannot be pickled.
  """
  return list(models)

@st.cache_resource(show_spinner='Loading model...')
def load_model(model_name: str):
  """Load the model registered under ``model_name`` and return it.

  Args:
    model_name: A key of the module-level ``models`` dict.

  Returns:
    Whatever the registered wrapper's ``.load()`` returns. ``st.cache_resource``
    caches it, so the load is not repeated for the same name.
  """
  selected = models[model_name]
  return selected.load()

# Browser-tab title/icon and wide layout. The emoji is the real 🚀 character;
# the previous literal was a mis-decoded (mojibake) copy of it.
st.set_page_config(page_title='TL playground', page_icon='🚀', layout='wide')
st.title('🚀 Topic Labelling playground')

# On screens at least 1500px wide, cap the main content container at this
# percentage of the viewport width (injected as raw CSS below).
percentage_width_main = 70
st.markdown(
    f'''<style>
    @media only screen and (min-width: 1500px) {{
      .appview-container .main .block-container{{
        max-width: {percentage_width_main}%;
      }}
    }}
    </style>
    ''',
    unsafe_allow_html=True,
)

# Left column: model / dataset / row selection. Right column: prompt and output.
col1, col2 = st.columns(2, gap='medium')

sel_model_name = col1.selectbox('Select a model', models, index=None, placeholder='Select a model')
if sel_model_name:
  model = load_model(sel_model_name)

sel_dataset_name = col1.selectbox('Select a dataset', topicsets.keys(), index=None)
if sel_dataset_name:
  # header=None: the file is read without a header row, so columns are 0..k.
  sel_dataset = pd.read_csv(topicsets[sel_dataset_name], header=None)
  col1.dataframe(sel_dataset)

  # The options are index *labels*, so the row is looked up with .loc below.
  sel_row_index = col1.selectbox('Select a row', sel_dataset.index)

sel_prompt = col2.selectbox('Select a prompt', prompts.keys())
if sel_prompt != 'custom_prompt':
  col2.code(prompts[sel_prompt], language='text')
  sel_prompt_text = prompts[sel_prompt]
else:
  # Keep the editor in the right-hand column, next to its caption.
  sel_prompt_text = col2.text_area('Custom prompt', height=200)
  col2.caption('Make sure to use "[KEYWORDS]" to indicate where the keywords should be inserted.')

# Generating requires a model, a dataset (and therefore a row) and a non-empty prompt.
btn_generate = col2.button(
  'Generate',
  disabled=(sel_model_name is None or sel_dataset_name is None or not sel_prompt_text.strip()),
)
if btn_generate:
  # Column 0 is skipped. NOTE(review): presumably a topic id — confirm
  # against the CSV. Empty cells (NaN, e.g. rows with fewer keywords) are
  # dropped and every value is stringified so join() cannot fail on
  # non-str cells.
  row_values = sel_dataset.loc[sel_row_index].tolist()[1:]
  keywords = ','.join(str(value) for value in row_values if pd.notna(value))

  placeholder = col2.empty()
  with placeholder, st.spinner('Generating...'):
    prompt = sel_prompt_text.replace('[KEYWORDS]', keywords)
    result = model.generate(prompt)

  message = col2.chat_message("ai")
  message.write(result)
  message.caption('Keywords: ' + keywords)