Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import streamlit as st | |
| from langchain import PromptTemplate, HuggingFaceHub, LLMChain | |
| from langchain.llms import OpenAI | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import os | |
| import re | |
def extract_positive_negative(text):
    """Return every whole-word "positive" / "negative" found in *text*.

    Matches are case-sensitive, must be delimited by word boundaries, and are
    returned as a list in order of appearance (empty when nothing matches).
    """
    return [
        hit.group(0)
        for hit in re.finditer(r'\b(?:positive|negative)\b', text)
    ]
def classify_text(text, llm_chain, api):
    """Classify a single text with the given chain and normalise the answer.

    Args:
        text: Value to classify; it is converted with ``str()`` before being
            passed to the chain.
        llm_chain: Object exposing ``run(str) -> str``.
        api: Name of the backing API, ``"HuggingFace"`` or ``"OpenAI"``.

    Returns:
        The chain's answer with all whitespace removed and lower-cased.
        Because every whitespace character is stripped, multi-word answers
        are concatenated (``"very positive"`` becomes ``"verypositive"``).

    Raises:
        ValueError: If *api* is not one of the supported names.
    """
    # Both supported APIs are queried the same way; validating up front gives
    # a clear error instead of an unbound-variable failure further down.
    if api not in ("HuggingFace", "OpenAI"):
        raise ValueError(f"Unsupported API: {api!r}")

    classification = llm_chain.run(str(text))
    return re.sub(r'\s', '', classification).lower()
def classify_csv(df, llm_chain, api):
    """Add model predictions next to the gold labels of *df*.

    The frame is modified in place: its "label" column is renamed to
    "label_gold" (ending up as the last existing column) and a new
    "label_pred" column holds the classification of each "text" entry.
    The same frame object is returned.
    """
    # pop + reassign drops "label" and re-adds its values under the new name
    df["label_gold"] = df.pop("label")
    df["label_pred"] = df["text"].apply(
        lambda entry: classify_text(entry, llm_chain=llm_chain, api=api)
    )
    return df
def classify_csv_zero(zero_file, llm_chain, api):
    """Read a ';'-separated CSV and label each of its "text" entries.

    Returns the loaded DataFrame with the predictions stored in a "label"
    column.
    """
    frame = pd.read_csv(zero_file, sep=';')
    predictions = frame["text"].apply(
        lambda entry: classify_text(entry, llm_chain=llm_chain, api=api)
    )
    frame["label"] = predictions
    return frame
def evaluate_performance(df):
    """Return the percentage of rows whose prediction equals the gold label.

    Args:
        df: DataFrame with "label_gold" and "label_pred" columns.

    Returns:
        The share of rows where both columns are equal, as a float between
        0 and 100. An empty frame yields 0.0 instead of dividing by zero.
    """
    total_preds = len(df)
    if total_preds == 0:
        # Nothing to score; avoid ZeroDivisionError on header-only files.
        return 0.0

    correct_preds = int((df["label_gold"] == df["label_pred"]).sum())
    return correct_preds / total_preds * 100
def _build_llm_chain(api, api_key, model, temperature, prompt):
    """Export *api_key* for the selected provider and return an LLMChain for it."""
    if api == "HuggingFace":
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
        llm = HuggingFaceHub(
            repo_id=model,
            model_kwargs={"temperature": temperature, "max_length": 128},
        )
    elif api == "OpenAI":
        os.environ["OPENAI_API_KEY"] = api_key
        llm = OpenAI(temperature=temperature)
    else:
        raise ValueError(f"Unsupported API: {api!r}")
    return LLMChain(prompt=prompt, llm=llm)


def _show_results(df, file_name):
    """Render *df* and offer it as a CSV download named *file_name*."""
    st.write(df)
    st.download_button(
        label="Download CSV",
        data=df.to_csv(index=False),
        file_name=file_name,
        mime="text/csv",
    )


def display_home():
    """Render the annotation page and run the classification on request.

    Collects the API, model, API key, temperature, setup ("Test" or
    "Zero-Shot"), the uploaded CSV and the prompt template. When the run
    button is pressed, every required input is validated first; a missing
    input produces a warning and stops the run, so the model chain is never
    used before it has been created. In the "Test" setup the predictions are
    additionally compared against the gold labels.
    """
    st.write("Please select an API and a model to classify the text. We currently support HuggingFace and OpenAI.")
    api = st.selectbox("Select an API", ["HuggingFace", "OpenAI"])

    # Only HuggingFace needs an explicit model choice.
    model = None
    if api == "HuggingFace":
        model = st.selectbox("Select a model", ["google/flan-t5-xl", "databricks/dolly-v1-6b"])
    # The label resolves to "HuggingFace API Key" / "OpenAI API Key"; the key
    # is masked so it is not shown in clear text on screen.
    api_key = st.text_input(f"{api} API Key", type="password")

    st.write("Please select a temperature for the model. The higher the temperature, the more creative the model will be.")
    temperature = st.slider("Set the temperature", min_value=0.0, max_value=1.0, value=0.0, step=0.01)

    st.write("We provide two different setups for the annotation task. In the first setup (**Test**), you can upload a CSV file with gold labels and evaluate the performance of the model. In the second setup (**Zero-Shot**), you can upload a CSV file without gold labels and use the model to classify the text.")
    setup = st.selectbox("Setup", ["Test", "Zero-Shot"])

    # Keep both names defined regardless of the chosen setup.
    gold_file = None
    zero_file = None
    if setup == "Test":
        gold_file = st.file_uploader("Upload Gold Labels CSV file with a text and a label column", type=["csv"])
    elif setup == "Zero-Shot":
        zero_file = st.file_uploader("Upload CSV file with a text column", type=["csv"])

    st.write("Please enter the prompt template below. You can use the following variables: {text} (text to classify).")
    prompt_template = st.text_area("Enter your task description", """Instruction: Identify the sentiment of a text. Please read the text and provide one of these responses: "positive" or "negative".\nText to classify in "positive" or "negative": {text}\nAnswer:""", height=200)

    classify_button = st.button("Run Classification/ Annotation")
    if not classify_button:
        return

    # --- validate inputs before anything is sent to the model ---
    if not prompt_template:
        st.warning("Please enter a prompt question to classify the text.")
        return
    if not api_key:
        st.warning(f"Please enter your {api} API key to classify the text.")
        return
    if setup == "Zero-Shot" and zero_file is None:
        st.warning("Please upload a CSV file with a text column to classify.")
        return
    if setup == "Test" and gold_file is None:
        st.warning("Please upload a gold labels CSV file to evaluate the performance of the model.")
        return

    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["text"],
    )
    llm_chain = _build_llm_chain(api, api_key, model, temperature, prompt)

    if setup == "Zero-Shot":
        df_predicted = classify_csv_zero(zero_file, llm_chain, api)
        _show_results(df_predicted, "classified_zero-shot_data.csv")
        return

    # "Test" setup: classify, then score against the gold labels.
    df = pd.read_csv(gold_file, sep=';')
    if "label" not in df.columns:
        st.warning("Please make sure that the gold labels CSV file contains a column named 'label'.")
        return
    if "text" not in df.columns:
        st.warning("Please make sure that the gold labels CSV file contains a column named 'text'.")
        return

    df = classify_csv(df, llm_chain, api)
    _show_results(df, "classified_test_data.csv")
    percentage_overlap = evaluate_performance(df)
    st.write("**Performance Evaluation**")
    st.write(f"Percentage overlap between gold labels and predicted labels: {percentage_overlap:.2f}%")
def main():
    """Configure the page, seed the session state, and render sidebar and home view."""
    st.set_page_config(page_title="PromptCards Playground", page_icon=":pencil2:")
    st.title("AInnotator")

    # Seed session-state entries on the first run only, so values already
    # stored in the session are not overwritten on reruns.
    if "current_page" not in st.session_state:
        st.session_state.current_page = "homepage"
    if "selected_prompt" not in st.session_state:
        st.session_state.selected_prompt = ""

    # Sidebar: static project description and links.
    st.sidebar.title("About")
    st.sidebar.write("AInnotator 🤖🏷️ is a tool for creating artificial labels/ annotations. It is based on the concept of PromptCards, which are small, self-contained descriptions of a task that can be used to generate labels for a wide range of NLP tasks. Check out the GitHub repository and the PromptCards Archive for more information.")
    st.sidebar.write("---")
    st.sidebar.write("Check out the [PromptCards archive](https://huggingface.co/spaces/chkla/AnnotationPromptCards) to find a wide range of prompts for different NLP tasks.")
    st.sidebar.write("---")
    st.sidebar.write("Made with ❤️ and 🤖.")

    display_home()
# Run the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()