import pandas as pd
import streamlit as st
from langchain import PromptTemplate, HuggingFaceHub, LLMChain
from langchain.llms import OpenAI
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import os
import re
def extract_positive_negative(text):
    """Return every standalone "positive"/"negative" word in *text*, in order.

    Matching is case-sensitive and anchored on word boundaries, so tokens such
    as "Positive" or "positively" are not reported.
    """
    word_pattern = re.compile(r'\b(?:positive|negative)\b')
    return [match.group(0) for match in word_pattern.finditer(text)]
def classify_text(text, llm_chain, api):
    """Classify one text with *llm_chain* and normalize the model's answer.

    Args:
        text: Value to classify. It is passed to the chain as the prompt's
            ``{text}`` variable; non-string CSV cells are coerced with ``str``.
        llm_chain: A LangChain ``LLMChain`` whose prompt takes a single
            ``text`` input (as built in ``display_home``).
        api: Backend name, ``"HuggingFace"`` or ``"OpenAI"``.

    Returns:
        The raw model answer with all whitespace removed, lower-cased.

    Raises:
        ValueError: If *api* is not one of the supported backends.
    """
    if api not in ("HuggingFace", "OpenAI"):
        # Previously an unknown backend surfaced as an UnboundLocalError.
        raise ValueError(f"Unsupported API: {api!r}")
    # Both backends are wrapped in an LLMChain, so the same call serves either.
    classification = llm_chain.run(str(text))
    # Models often pad or break their answer; drop every whitespace character.
    classification = re.sub(r'\s', '', classification)
    return classification.lower()
def classify_csv(df, llm_chain, api):
    """Move the gold labels aside and append model predictions.

    The ``label`` column is removed and re-added at the end of the frame as
    ``label_gold``; a ``label_pred`` column holding ``classify_text`` output
    for each ``text`` cell is appended after it. *df* is modified in place
    and also returned.
    """
    gold_labels = df.pop("label")
    df["label_gold"] = gold_labels
    df["label_pred"] = df["text"].apply(
        lambda cell: classify_text(cell, llm_chain=llm_chain, api=api)
    )
    return df
def classify_csv_zero(zero_file, llm_chain, api):
    """Read an unlabelled ';'-separated CSV and annotate every row.

    Returns the parsed frame with a ``label`` column holding the
    ``classify_text`` result for each ``text`` cell.
    """
    frame = pd.read_csv(zero_file, sep=';')
    predictions = frame["text"].apply(classify_text, llm_chain=llm_chain, api=api)
    return frame.assign(label=predictions)
def evaluate_performance(df):
    """Return the percentage (0-100) of rows whose gold and predicted labels agree.

    Args:
        df: Frame with ``label_gold`` and ``label_pred`` columns, as produced
            by ``classify_csv``.

    Returns:
        ``correct / total * 100`` as a float, or ``0.0`` for an empty frame
        (instead of raising ``ZeroDivisionError``).
    """
    total_preds = len(df)
    if total_preds == 0:
        return 0.0

    def normalize(labels):
        # Mirror classify_text's normalization (drop all whitespace, lower-case)
        # so a gold "Positive " is counted as matching a predicted "positive".
        return labels.astype(str).str.replace(r"\s", "", regex=True).str.lower()

    matches = normalize(df["label_gold"]) == normalize(df["label_pred"])
    correct_preds = int(matches.sum())
    return correct_preds / total_preds * 100
def display_home():
    """Render the classification/annotation page and run it on demand.

    The user picks a backend (HuggingFace Hub or OpenAI), a model, an API key,
    a temperature, a setup ("Test" with gold labels or "Zero-Shot" without)
    and a prompt template containing a ``{text}`` placeholder. When the run
    button is pressed, each row of the uploaded ';'-separated CSV is
    classified and the result is offered as a CSV download. In the "Test"
    setup, agreement with the gold labels is also reported.
    """
    st.write("Please select an API and a model to classify the text. We currently support HuggingFace and OpenAI.")
    api = st.selectbox("Select an API", ["HuggingFace", "OpenAI"])
    if api == "HuggingFace":
        model = st.selectbox("Select a model", ["google/flan-t5-xl", "databricks/dolly-v1-6b"])
        # Masked input so the key is not shown on screen.
        api_key = st.text_input("HuggingFace API Key", type="password")
    else:
        model = None
        api_key = st.text_input("OpenAI API Key", type="password")
    st.write("Please select a temperature for the model. The higher the temperature, the more creative the model will be.")
    temperature = st.slider("Set the temperature", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
    st.write("We provide two different setups for the annotation task. In the first setup (**Test**), you can upload a CSV file with gold labels and evaluate the performance of the model. In the second setup (**Zero-Shot**), you can upload a CSV file without gold labels and use the model to classify the text.")
    setup = st.selectbox("Setup", ["Test", "Zero-Shot"])
    gold_file = None
    zero_file = None
    if setup == "Test":
        gold_file = st.file_uploader("Upload Gold Labels CSV file with a text and a label column", type=["csv"])
    else:
        zero_file = st.file_uploader("Upload CSV file with a text column", type=["csv"])
    st.write("Please enter the prompt template below. You can use the following variables: {text} (text to classify).")
    prompt_template = st.text_area("Enter your task description", """Instruction: Identify the sentiment of a text. Please read the text and provide one of these responses: "positive" or "negative".\nText to classify in "positive" or "negative": {text}\nAnswer:""", height=200)

    if not st.button("Run Classification/ Annotation"):
        return
    if not prompt_template:
        st.warning("Please enter a prompt question to classify the text.")
        return
    llm_chain = _build_llm_chain(api, model, api_key, temperature, prompt_template)
    if llm_chain is None:
        # The helper has already shown a warning explaining what is missing.
        return
    if setup == "Zero-Shot":
        _run_zero_shot(zero_file, llm_chain, api)
    else:
        _run_test(gold_file, llm_chain, api)


def _build_llm_chain(api, model, api_key, temperature, prompt_template):
    """Return an LLMChain for *api*, or None after showing a warning in the UI."""
    if not api_key:
        st.warning(f"Please enter your {api} API key to classify the text.")
        return None
    if "{text}" not in prompt_template:
        st.warning("Please include the {text} variable in the prompt template.")
        return None
    try:
        prompt = PromptTemplate(template=prompt_template, input_variables=["text"])
    except ValueError as err:
        # Raised when the template uses variables other than {text}.
        st.warning(f"The prompt template is invalid: {err}")
        return None
    # Keys are passed to the wrappers directly rather than via os.environ:
    # the environment is process-wide, so every session served by this
    # process would otherwise share (and use) the last user's key.
    if api == "HuggingFace":
        llm = HuggingFaceHub(
            repo_id=model,
            huggingfacehub_api_token=api_key,
            model_kwargs={"temperature": temperature, "max_length": 128},
        )
    else:
        llm = OpenAI(temperature=temperature, openai_api_key=api_key)
    return LLMChain(prompt=prompt, llm=llm)


def _run_zero_shot(zero_file, llm_chain, api):
    """Classify an unlabelled CSV and offer the annotated result for download."""
    if zero_file is None:
        st.warning("Please upload a CSV file with a text column to classify the text.")
        return
    df_predicted = classify_csv_zero(zero_file, llm_chain, api)
    # NOTE(review): download file names are placeholders; confirm desired names.
    _offer_download(df_predicted, "zero_shot_predictions.csv")


def _run_test(gold_file, llm_chain, api):
    """Classify a gold-labelled CSV, offer it for download and report agreement."""
    if gold_file is None:
        st.warning("Please upload a gold labels CSV file to evaluate the performance of the model.")
        return
    df = pd.read_csv(gold_file, sep=';')
    for column in ("text", "label"):
        if column not in df.columns:
            st.warning(f"Please make sure that the gold labels CSV file contains a column named '{column}'.")
            # Stop here; classify_csv would otherwise fail with a KeyError.
            return
    df = classify_csv(df, llm_chain, api)
    _offer_download(df, "test_predictions.csv")
    percentage_overlap = evaluate_performance(df)
    st.write("**Performance Evaluation**")
    st.write(f"Percentage overlap between gold labels and predicted labels: {percentage_overlap:.2f}%")


def _offer_download(df, file_name):
    """Show a button that downloads *df* as a ';'-separated CSV file."""
    # ';' matches the separator the uploaded files are read with.
    st.download_button(
        label="Download CSV",
        data=df.to_csv(sep=';', index=False),
        file_name=file_name,
        mime="text/csv",
    )
def main():
    """Configure the Streamlit page and seed per-session state defaults.

    Sets the page title and icon, then ensures ``current_page`` and
    ``selected_prompt`` exist in ``st.session_state`` without overwriting
    values already stored there.
    """
    st.set_page_config(page_title="PromptCards Playground", page_icon=":pencil2:")
    session_defaults = (
        ("current_page", "homepage"),
        ("selected_prompt", ""),
    )
    for state_key, default_value in session_defaults:
        # Only seed keys that are missing so existing state is preserved.
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value
    # Sidebar menu entries.
    # NOTE(review): `menu` is not consumed by any code visible here; confirm
    # the sidebar/page-dispatch logic that uses it is present.
    menu = ["Homepage", "Playground", "Prompt Archive", "Annotator", "About"]
if __name__ == "__main__":