# Source: Hugging Face Space "Wedyan2023" — app.py (commit dc141e7, verified; raw file, 5.83 kB)
import streamlit as st
from openai import OpenAI
# Conversation history for the chat API, persisted across Streamlit reruns.
# st.session_state is a MutableMapping, so setdefault seeds the key only once.
st.session_state.setdefault('messages', [])
def create_system_prompt(classification_type, num_to_generate, domain, min_words, max_words, labels):
    """Build the system prompt sent to the model.

    The prompt instructs the model to produce exactly ``num_to_generate``
    labeled examples for ``domain``, each between ``min_words`` and
    ``max_words`` words long, restricted to the given ``labels`` and
    formatted one per line as ``Example: <text>, Label: <label>``.
    """
    label_list = ", ".join(labels)
    parts = (
        f"You are a professional {classification_type.lower()} expert. "
        f"Your role is to generate exactly {num_to_generate} data examples for {domain}. ",
        f"Each example should consist of between {min_words} and {max_words} words. ",
        f"Use the following labels: {label_list}. Please do not add any extra commentary or explanation. ",
        "Format each example like this: \nExample: <text>, Label: <label>\n",
    )
    return "".join(parts)
# OpenAI client setup (replace with your OpenAI API credentials)
# NOTE(review): hardcoded API key placeholder — load from an environment
# variable or st.secrets before deploying; never commit a real key.
client = OpenAI(api_key='YOUR_API_KEY')
# App title
st.title("Data Generation for Classification")
# Top-level task switch; only the "Data Generation" branch is visible in
# this chunk — "Data Labeling" is presumably handled elsewhere in the file.
mode = st.radio("Choose Task:", ["Data Generation", "Data Labeling"])
if mode == "Data Generation":
    # Step 1: classification task type — drives the label choices below.
    classification_type = st.radio(
        "Select Classification Type:",
        ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"]
    )

    # Step 2: labels — fixed set for sentiment, user-supplied otherwise.
    if classification_type == "Sentiment Analysis":
        labels = ["Positive", "Negative", "Neutral"]
    elif classification_type == "Binary Classification":
        class1 = st.text_input("Enter First Class for Binary Classification")
        class2 = st.text_input("Enter Second Class for Binary Classification")
        labels = [class1, class2]
    else:  # Multi-Class Classification
        num_classes = st.slider("Number of Classes (Max 10):", 2, 10, 3)
        labels = [st.text_input(f"Enter Class {i+1}") for i in range(num_classes)]

    # Step 3: domain, either a preset or free text.
    domain = st.radio(
        "Select Domain:",
        ["Restaurant reviews", "E-commerce reviews", "Custom"]
    )
    if domain == "Custom":
        domain = st.text_input("Enter Custom Domain")

    # Step 4: per-example word-count bounds. Clamp so the prompt never
    # contains a contradictory "between 40 and 20 words" instruction.
    min_words = st.slider("Minimum Words per Example", 10, 90, 20)
    max_words = st.slider("Maximum Words per Example", 10, 90, 40)
    if max_words < min_words:
        st.warning("Maximum words raised to match the minimum.")
        max_words = min_words

    # Step 5: optional few-shot examples, sent to the model as user messages.
    use_few_shot = st.checkbox("Use Few-Shot Examples?")
    few_shot_examples = []
    if use_few_shot:
        num_few_shots = st.slider("Number of Few-Shot Examples (Max 5):", 1, 5, 2)
        for i in range(num_few_shots):
            example_text = st.text_area(f"Enter Example {i+1} Text")
            example_label = st.selectbox(f"Select Label for Example {i+1}", labels)
            few_shot_examples.append(f"Example: {example_text}, Label: {example_label}")

    # Step 6: total number of examples to generate (fetched in chunks of 5).
    num_to_generate = st.number_input("Number of Examples to Generate",
                                      min_value=1, max_value=50, value=10)

    if st.button("Generate Examples"):
        # Guard: user-supplied label fields may still be empty strings.
        if any(not (label and label.strip()) for label in labels):
            st.error("Please fill in every class label before generating.")
        else:
            all_generated_examples = []
            remaining_examples = num_to_generate
            with st.spinner("Generating..."):
                while remaining_examples > 0:
                    chunk_size = min(remaining_examples, 5)
                    # BUG FIX: ask for chunk_size examples per request (the old
                    # code told the model to produce the full total each time),
                    # and build a fresh message list per request instead of
                    # appending the system prompt to session state on every
                    # iteration, which duplicated prompts in later calls.
                    system_prompt = create_system_prompt(
                        classification_type, chunk_size, domain,
                        min_words, max_words, labels
                    )
                    messages = [{"role": "system", "content": system_prompt}]
                    messages.extend(
                        {"role": "user", "content": example}
                        for example in few_shot_examples
                    )
                    try:
                        stream = client.chat.completions.create(
                            model="gpt-3.5-turbo",
                            messages=messages,
                            temperature=0.7,
                            stream=True,
                            max_tokens=3000,
                        )
                        # BUG FIX: the openai>=1.0 client (imported as
                        # `from openai import OpenAI`) streams typed chunk
                        # objects, not dicts — use attribute access and skip
                        # None deltas (e.g. the role-only first chunk).
                        response = ""
                        for chunk in stream:
                            delta = chunk.choices[0].delta.content
                            if delta:
                                response += delta
                        # Split on the "Example: " marker; cap at chunk_size in
                        # case the model over-produces.
                        generated_examples = response.split("Example: ")[1:chunk_size + 1]
                        # BUG FIX: store the bare text — the old code prefixed
                        # "Example N:" here AND again at display time.
                        cleaned_examples = [ex.strip() for ex in generated_examples]
                        if not cleaned_examples:
                            # Nothing parseable came back; stop rather than
                            # loop forever against a misbehaving model.
                            st.warning("Model returned no parseable examples; stopping.")
                            break
                        all_generated_examples.extend(cleaned_examples)
                        # BUG FIX: decrement by what was actually parsed, not
                        # the requested chunk size, so short responses retry.
                        remaining_examples -= len(cleaned_examples)
                    except Exception as e:
                        st.error("Error during generation.")
                        st.write(e)
                        break

            # Display all generated examples with a single numbering pass.
            for idx, example in enumerate(all_generated_examples):
                st.write(f"Example {idx+1}: {example}")

            # Reset conversation history so stale prompts don't accumulate
            # across runs.
            st.session_state.messages = []