chkla committed on
Commit 20e2dd0
1 Parent(s): 4def8e0

Upload 2 files

Files changed (2)
  1. app_anno.py +146 -0
  2. requirements.txt +5 -0
app_anno.py ADDED
@@ -0,0 +1,146 @@
+ import pandas as pd
+ import streamlit as st
+ from langchain import PromptTemplate, HuggingFaceHub, LLMChain
+ from langchain.llms import OpenAI
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import os
+ import re
+
+
+ def extract_positive_negative(text):
+     # pull every standalone "positive"/"negative" token out of a model answer
+     pattern = r'\b(?:positive|negative)\b'
+     result = re.findall(pattern, text)
+     return result
+
+ def classify_text(text, llm_chain, api):
+     if api == "HuggingFace":
+         classification = llm_chain.run(str(text))
+     elif api == "OpenAI":
+         classification = llm_chain.run(str(text))
+     # strip any whitespace the model wraps around its answer before comparing labels
+     classification = re.sub(r'\s', '', classification)
+     return classification.lower()
+
+ def classify_csv(df, llm_chain, api):
+     # keep the uploaded gold labels and add a column with the model predictions
+     df["label_gold"] = df["label"]
+     del df["label"]
+     df["label_pred"] = df["text"].apply(classify_text, llm_chain=llm_chain, api=api)
+     return df
+
+ def classify_csv_zero(zero_file, llm_chain, api):
+     df = pd.read_csv(zero_file, sep=';')
+     df["label"] = df["text"].apply(classify_text, llm_chain=llm_chain, api=api)
+     return df
+
+ def evaluate_performance(df):
+     # percentage of rows where the predicted label matches the gold label
+     correct_preds = sum(df["label_gold"] == df["label_pred"])
+     total_preds = len(df)
+     percentage_overlap = correct_preds / total_preds * 100
+     return percentage_overlap
+
+ def display_home():
+     st.write("Please select an API and a model to classify the text. We currently support HuggingFace and OpenAI.")
+     api = st.selectbox("Select an API", ["HuggingFace", "OpenAI"])
+
+     if api == "HuggingFace":
+         model = st.selectbox("Select a model", ["google/flan-t5-xl", "databricks/dolly-v1-6b"])
+         api_key_hug = st.text_input("HuggingFace API Key")
+     elif api == "OpenAI":
+         model = None
+         api_key_openai = st.text_input("OpenAI API Key")
+
+     st.write("Please select a temperature for the model. The higher the temperature, the more creative the model will be.")
+     temperature = st.slider("Set the temperature", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
+
+     st.write("We provide two different setups for the annotation task. In the first setup (**Test**), you can upload a CSV file with gold labels and evaluate the performance of the model. In the second setup (**Zero-Shot**), you can upload a CSV file without gold labels and use the model to classify the text.")
+     setup = st.selectbox("Setup", ["Test", "Zero-Shot"])
+
+     if setup == "Test":
+         gold_file = st.file_uploader("Upload Gold Labels CSV file with a text and a label column", type=["csv"])
+     elif setup == "Zero-Shot":
+         gold_file = None
+         zero_file = st.file_uploader("Upload CSV file with a text column", type=["csv"])
+
+     st.write("Please enter the prompt template below. You can use the following variables: {text} (text to classify).")
+     prompt_template = st.text_area("Enter your task description", """Instruction: Identify the sentiment of a text. Please read the text and provide one of these responses: "positive" or "negative".\nText to classify in "positive" or "negative": {text}\nAnswer:""", height=200)
+
+     classify_button = st.button("Run Classification / Annotation")
+
+     if classify_button:
+         if prompt_template:
+             prompt = PromptTemplate(
+                 template=prompt_template,
+                 input_variables=["text"]
+             )
+
+             if api == "HuggingFace":
+                 if api_key_hug:
+                     os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key_hug
+                     llm_chain = LLMChain(prompt=prompt, llm=HuggingFaceHub(repo_id=model, model_kwargs={"temperature": temperature, "max_length": 128}))
+                 else:
+                     st.warning("Please enter your HuggingFace API key to classify the text.")
+             elif api == "OpenAI":
+                 if api_key_openai:
+                     os.environ["OPENAI_API_KEY"] = api_key_openai
+                     llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=temperature))
+                 else:
+                     st.warning("Please enter your OpenAI API key to classify the text.")
+
+             if setup == "Zero-Shot":
+                 if zero_file is not None:
+                     df_predicted = classify_csv_zero(zero_file, llm_chain, api)
+                     st.write(df_predicted)
+                     st.download_button(
+                         label="Download CSV",
+                         data=df_predicted.to_csv(index=False),
+                         file_name="classified_zero-shot_data.csv",
+                         mime="text/csv"
+                     )
+             elif setup == "Test":
+                 if gold_file is not None:
+                     df = pd.read_csv(gold_file, sep=';')
+                     if "label" not in df.columns:
+                         st.warning("Please make sure that the gold labels CSV file contains a column named 'label'.")
+                     else:
+                         df = classify_csv(df, llm_chain, api)
+                         st.write(df)
+                         st.download_button(
+                             label="Download CSV",
+                             data=df.to_csv(index=False),
+                             file_name="classified_test_data.csv",
+                             mime="text/csv"
+                         )
+                         percentage_overlap = evaluate_performance(df)
+                         st.write("**Performance Evaluation**")
+                         st.write(f"Percentage overlap between gold labels and predicted labels: {percentage_overlap:.2f}%")
+                 else:
+                     st.warning("Please upload a gold labels CSV file to evaluate the performance of the model.")
+         else:
+             st.warning("Please enter a prompt question to classify the text.")
+
+ def main():
+     st.set_page_config(page_title="PromptCards Playground", page_icon=":pencil2:")
+     st.title("AInnotator")
+
+     # track the current page in session state
+     if "current_page" not in st.session_state:
+         st.session_state.current_page = "homepage"
+
+     # Initialize selected_prompt in session_state if not set
+     if "selected_prompt" not in st.session_state:
+         st.session_state.selected_prompt = ""
+
+     # Sidebar menu and description
+     menu = ["Homepage", "Playground", "Prompt Archive", "Annotator", "About"]
+     st.sidebar.title("About")
+     st.sidebar.write("AInnotator 🤖🏷️ is a tool for creating artificial labels/annotations. It is based on the concept of PromptCards, which are small, self-contained descriptions of a task that can be used to generate labels for a wide range of NLP tasks. Check out the GitHub repository and the PromptCards Archive for more information.")
+     st.sidebar.write("---")
+     st.sidebar.write("Check out the [PromptCards archive]() to find a wide range of prompts for different NLP tasks.")
+     st.sidebar.write("---")
+     st.sidebar.write("Made with ❤️ and 🤖.")
+
+     display_home()
+
+ if __name__ == "__main__":
+     main()
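
Note (not part of the commit): a minimal sketch of the input the app expects, derived from the pd.read_csv(..., sep=';') calls and the column names used above. The file name and example rows are placeholders.

    import pandas as pd

    # "Test" setup: a semicolon-separated CSV with a 'text' and a 'label' column
    sample = pd.DataFrame({
        "text": ["I loved this movie.", "The service was terrible."],
        "label": ["positive", "negative"],
    })
    sample.to_csv("gold_labels.csv", sep=";", index=False)  # note the ';' separator

    # "Zero-Shot" setup: the same file without the 'label' column.
    # The app is launched the usual Streamlit way:  streamlit run app_anno.py
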
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ langchain
+ pandas
+ streamlit
+ transformers
+ scikit-learn
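
Note (not part of the commit): assuming a standard pip workflow, the dependencies can be installed with `pip install -r requirements.txt`.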