from saged import Pipeline
from tqdm import tqdm
from pathlib import Path
from saged import SAGEDData as dt
import streamlit as st
import json
import http.client
from openai import AzureOpenAI
import ollama
import time  # Use time.sleep to simulate processing steps
import logging
from io import StringIO
import sys

# Create a custom logging handler to capture log messages
class StreamlitLogHandler(logging.Handler):
    def __init__(self):
        super().__init__()
        self.log_capture_string = StringIO()

    def emit(self, record):
        # Write each log message to the StringIO buffer
        message = self.format(record)
        self.log_capture_string.write(message + "\n")

    def get_logs(self):
        # Return the log contents
        return self.log_capture_string.getvalue()

    def clear_logs(self):
        # Clear the log buffer
        self.log_capture_string.truncate(0)
        self.log_capture_string.seek(0)
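
# Note: StreamlitLogHandler is defined here but not attached to a logger anywhere below.
# A minimal (hypothetical) way to wire it up would be:
#   handler = StreamlitLogHandler()
#   handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
#   logging.getLogger().addHandler(handler)
# after which handler.get_logs() returns whatever the handler has captured so far.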

# Define ContentFormatter class
class ContentFormatter:
    @staticmethod
    def chat_completions(text, settings_params):
        message = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": text}
        ]
        data = {"messages": message, **settings_params}
        return json.dumps(data)
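
# For illustration, ContentFormatter.chat_completions("Hello", {"temperature": 0.5}) returns the
# JSON string {"messages": [{"role": "system", "content": "You are a helpful assistant."},
# {"role": "user", "content": "Hello"}], "temperature": 0.5}.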

# Define OllamaModel (For local Ollama interaction)
class OllamaModel:
    def __init__(self, base_model='llama3', system_prompt='You are a helpful assistant', model_name='llama3o',
                 **kwargs):
        self.base_model = base_model
        self.model_name = model_name
        self.model_create(model_name, system_prompt, base_model, **kwargs)

    def model_create(self, model_name, system_prompt, base_model, **kwargs):
        modelfile = f'FROM {base_model}\nSYSTEM {system_prompt}\n'
        if kwargs:
            for key, value in kwargs.items():
                modelfile += f'PARAMETER {key.lower()} {value}\n'
        ollama.create(model=model_name, modelfile=modelfile)

    def invoke(self, prompt):
        answer = ollama.generate(model=self.model_name, prompt=prompt)
        return answer['response']
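
# Usage sketch (assumes a local Ollama server with the base model already pulled):
#   local_model = OllamaModel(base_model='llama3', system_prompt='You are a helpful assistant',
#                             model_name='llama3o', temperature=0.7)
#   print(local_model.invoke("Say hello"))
# The modelfile-based ollama.create() call above matches older releases of the ollama Python
# client; newer releases may expect a different signature.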

# Define GPTAgent (for Azure OpenAI GPT deployments via the openai SDK)
class GPTAgent:
    def __init__(self, model_name, azure_key, azure_version, azure_endpoint, deployment_name):
        self.client = AzureOpenAI(
            api_key=azure_key,
            api_version=azure_version,
            azure_endpoint=azure_endpoint
        )
        self.deployment_name = deployment_name

    def invoke(self, prompt, settings_params=None):
        if not settings_params:
            settings_params = {}
        formatted_input = ContentFormatter.chat_completions(prompt, settings_params)
        response = self.client.chat.completions.create(
            model=self.deployment_name,
            messages=json.loads(formatted_input)['messages'],
            **settings_params
        )
        return response.choices[0].message.content
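
# Usage sketch with placeholder credentials (illustration only, not executed by the app):
#   gpt = GPTAgent(model_name='gpt-35-turbo', azure_key='<key>', azure_version='2023-05-15',
#                  azure_endpoint='https://<resource>.openai.azure.com/', deployment_name='<deployment>')
#   gpt.invoke("Summarise this paragraph.", {"temperature": 0.2})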

# Define AzureAgent (for Azure-hosted chat endpoints called directly over HTTPS)
class AzureAgent:
    def __init__(self, model_name, azure_uri, azure_api_key):
        self.azure_uri = azure_uri
        self.headers = {
            'Authorization': f"Bearer {azure_api_key}",
            'Content-Type': 'application/json'
        }
        self.chat_formatter = ContentFormatter

    def invoke(self, prompt, settings_params=None):
        if not settings_params:
            settings_params = {}
        body = self.chat_formatter.chat_completions(prompt, {**settings_params})
        conn = http.client.HTTPSConnection(self.azure_uri)
        conn.request("POST", '/v1/chat/completions', body=body, headers=self.headers)
        response = conn.getresponse()
        data = response.read()
        conn.close()
        decoded_data = data.decode("utf-8")
        parsed_data = json.loads(decoded_data)
        content = parsed_data["choices"][0]["message"]["content"]
        return content
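
# Note: http.client.HTTPSConnection expects a bare host name (for example
# "my-endpoint.eastus2.inference.ml.azure.com"), not a URL with an "https://" scheme,
# so the azure_uri passed to AzureAgent should be supplied without the scheme prefix.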

# Renew Source Finder Button
def renew_source_finder(domain, concept_list):
    if 'generated_synthetic_files' in st.session_state:
        del st.session_state['generated_synthetic_files']
    if not domain or not concept_list:
        st.error("Please fill in all the required fields before proceeding.")
    else:
        with st.spinner("Renewing source info files..."):
            base_path = Path('data/customized/source_finder/')
            for concept in concept_list:
                file_path = base_path / f'{domain}_{concept}_source_finder.json'
                if file_path.exists():
                    try:
                        file_path.unlink()  # Delete the file
                        st.info(f"Deleted source info file: {file_path}")
                    except Exception as e:
                        st.error(f"An error occurred while deleting the file {file_path}: {e}")
        st.success("Source info files renewal completed!")

def create_source_finder(domain, concept):
    source_specification_item = f"data/customized/local_files/{domain}/{concept}.txt"
    if not Path(source_specification_item).exists():
        st.warning(f"Local file does not exist: {source_specification_item}")
    instance = dt.create_data(domain, concept, 'source_finder')
    instance.data[0]['keywords'] = {concept: dt.default_keyword_metadata.copy()}
    category_shared_source_item = dt.default_source_item.copy()
    category_shared_source_item['source_type'] = "local_paths"
    category_shared_source_item['source_specification'] = [source_specification_item]
    instance.data[0]['category_shared_source'] = [category_shared_source_item]
    return instance.data.copy()
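
# create_source_finder() builds the in-memory SAGEDData payload for one concept: the concept itself
# as a keyword entry plus a single "local_paths" source pointing at
# data/customized/local_files/{domain}/{concept}.txt. The caller below serialises it to JSON.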

def check_and_create_source_files(domain, concept_list):
    """
    Check that the required source finder file exists for each concept in the domain.
    If a file is missing or fails to load, create a fresh source finder file for that concept.
    """
    base_path = Path('data/customized/source_finder/')
    base_path.mkdir(parents=True, exist_ok=True)
    for concept in concept_list:
        file_path = base_path / f'{domain}_{concept}_source_finder.json'
        if not file_path.exists():
            # Create a new source finder file using create_source_finder
            data = create_source_finder(domain, concept)
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=4)
            st.info(f"Created missing source finder file: {file_path}")
        else:
            # Attempt to load the file to verify its validity
            instance = dt.load_file(domain, concept, 'source_finder', file_path)
            if instance is None:
                # If loading fails, create a new valid file
                data = create_source_finder(domain, concept)
                with open(file_path, 'w', encoding='utf-8') as f:
                    json.dump(data, f, indent=4)
                st.info(f"Recreated invalid source finder file: {file_path}")

def clean_spaces(data):
    """
    Removes trailing or leading spaces from a string or from each element in a list.
    """
    if isinstance(data, str):
        return data.strip()
    elif isinstance(data, list):
        return [item.strip() if isinstance(item, str) else item for item in data]
    else:
        raise TypeError("Input should be either a string or a list of strings")

def create_replacement_dict(concept_list, replacer):
    replacement = {}
    for concept in concept_list:
        replacement[concept] = {}
        for company in replacer:
            replacement[concept][company] = {concept: company}
    return replacement
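
# For example (hypothetical values), create_replacement_dict(['excel-stock'], ['Company A', 'Company B'])
# returns:
#   {'excel-stock': {'Company A': {'excel-stock': 'Company A'},
#                    'Company B': {'excel-stock': 'Company B'}}}
# i.e. for every concept, a mapping from each replacer to the substitution it should perform.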

# Title of the app
st.title("SAGED-bias Benchmark-Building Demo")

# Initialize session state variables
if 'domain' not in st.session_state:
    st.session_state['domain'] = None
if 'concept_list' not in st.session_state:
    st.session_state['concept_list'] = None
if 'gpt_model' not in st.session_state:
    st.session_state['gpt_model'] = None
if 'azure_model' not in st.session_state:
    st.session_state['azure_model'] = None
if 'ollama_model' not in st.session_state:
    st.session_state['ollama_model'] = None

# Sidebar: Model Selection
with st.sidebar:
    st.header("Model Configuration")

    # Selection of which model to use
    model_selection = st.radio("Select Model Type", ['GPT-Azure', 'Azure', 'Ollama'])

    # Collapsible Additional Configuration Section
    with st.expander("Model Configuration"):
        if model_selection == 'Ollama':
            # Ollama Configuration
            ollama_deployment_name = st.text_input("Enter Ollama Model Deployment Name", placeholder="e.g., llama3")
            ollama_system_prompt = st.text_input("Enter System Prompt for Ollama",
                                                 placeholder="e.g., You are a helpful assistant.")
            if ollama_deployment_name and ollama_system_prompt:
                confirm_ollama = st.button("Confirm Ollama Configuration")
                if confirm_ollama:
                    st.session_state['ollama_model'] = OllamaModel(
                        model_name=ollama_deployment_name,
                        system_prompt=ollama_system_prompt
                    )
                    st.success("Ollama model configured successfully.")
            else:
                st.warning("Please provide both Ollama deployment name and system prompt.")
        elif model_selection in ('GPT-Azure', 'Azure'):
            # GPT / Azure Configuration
            gpt_azure_endpoint = st.text_input("Enter Azure Endpoint URL",
                                               placeholder="e.g., https://your-resource-name.openai.azure.com/")
            gpt_azure_api_key = st.text_input("Enter Azure API Key", type="password")
            gpt_azure_model_name = st.text_input("Enter Azure Model Name", placeholder="e.g., GPT-3.5-turbo")
            gpt_azure_deployment_name = st.text_input("Enter Azure Deployment Name",
                                                      placeholder="e.g., gpt-3-5-deployment")
            if gpt_azure_endpoint and gpt_azure_api_key and gpt_azure_model_name and gpt_azure_deployment_name:
                confirm_gpt_azure = st.button("Confirm GPT/Azure Configuration")
                if confirm_gpt_azure:
                    if model_selection == 'GPT-Azure':
                        st.session_state['gpt_model'] = GPTAgent(
                            model_name=gpt_azure_model_name,
                            azure_key=gpt_azure_api_key,
                            azure_version='2023-05-15',  # Update if necessary
                            azure_endpoint=gpt_azure_endpoint,
                            deployment_name=gpt_azure_deployment_name
                        )
                        st.success("GPT model configured successfully.")
                    elif model_selection == 'Azure':
                        st.session_state['azure_model'] = AzureAgent(
                            model_name=gpt_azure_model_name,
                            azure_uri=gpt_azure_endpoint,
                            azure_api_key=gpt_azure_api_key
                        )
                        st.success("Azure model configured successfully.")
            else:
                st.warning("Please provide all fields for GPT/Azure configuration.")

# Main interaction based on configured model
if st.session_state.get('ollama_model'):
    model = st.session_state['ollama_model']
elif st.session_state.get('gpt_model'):
    model = st.session_state['gpt_model']
elif st.session_state.get('azure_model'):
    model = st.session_state['azure_model']
else:
    model = None

# User input: Domain and Concepts
with st.form(key='domain_concept_form'):
    domain = clean_spaces(
        st.text_input("Enter the domain: (e.g., Stocks, Education)", placeholder="Enter domain here..."))

    # User input: Concepts
    concept_text = st.text_area("Enter the concepts (separated by commas):",
                                placeholder="e.g., excel-stock, ok-stock, bad-stock")
    concept_list = clean_spaces(concept_text.split(','))

    submit_button = st.form_submit_button(label='Confirm Domain and Concepts')

if submit_button:
    if not domain:
        st.warning("Please enter a domain.")
    elif not concept_list or concept_text.strip() == "":
        st.warning("Please enter at least one concept.")
    else:
        st.session_state['domain'] = domain
        st.session_state['concept_list'] = concept_list
        st.success("Domain and concepts confirmed.")

# Display further options only after domain and concepts are confirmed
if st.session_state['domain'] and st.session_state['concept_list']:
    with st.expander("Additional Options"):

        # User input: Method
        scraper_method = st.radio("Select the scraper method:", ('wiki', 'local_files', 'synthetic_files'))

        # Initiate the source_finder_requirement and keyword_finder_requirement if 'wiki' is selected
        if scraper_method == 'wiki':
            st.session_state['keyword_finder_requirement'] = True
            st.session_state['source_finder_requirement'] = True
            st.session_state['check_source_finder'] = False

        # File upload for each concept if 'local_files' is selected
        if scraper_method == 'local_files':
            uploaded_files = {}
            st.session_state['keyword_finder_requirement'] = False
            st.session_state['source_finder_requirement'] = False
            st.session_state['check_source_finder'] = True
            for concept in st.session_state['concept_list']:
                uploaded_file = st.file_uploader(f"Upload file for concept '{concept}':", type=['txt'],
                                                 key=f"file_{concept}")
                if uploaded_file:
                    uploaded_files[concept] = uploaded_file
                    # Save uploaded file
                    save_path = Path(f"data/customized/local_files/{st.session_state['domain']}/{concept}.txt")
                    save_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(save_path, 'wb') as f:
                        f.write(uploaded_file.getbuffer())
                    st.success(f"File for concept '{concept}' saved successfully.")

        # Generate synthetic files if 'synthetic_files' is selected
        if scraper_method == 'synthetic_files':
            scraper_method = 'local_files'  # Synthetic files are scraped as local files once generated
            st.session_state['keyword_finder_requirement'] = False
            st.session_state['source_finder_requirement'] = False
            st.session_state['check_source_finder'] = True
            if 'generated_synthetic_files' not in st.session_state:
                st.session_state['generated_synthetic_files'] = set()
            prompt_inputs = {}
            for concept in st.session_state['concept_list']:
                if concept not in st.session_state['generated_synthetic_files']:
                    prompt_inputs[concept] = st.text_input(
                        f"Enter the prompt for concept '{concept}':",
                        value=f"Write a long article introducing the {concept} in the {st.session_state['domain']}. Use the {concept} as much as possible.",
                        key=f"prompt_{concept}"
                    )
            if st.button("Generate Synthetic Files for All Concepts"):
                if model:
                    for concept, prompt in prompt_inputs.items():
                        if prompt:
                            with st.spinner(f"Generating content for concept '{concept}'..."):
                                synthetic_content = model.invoke(prompt)
                                save_path = Path(
                                    f"data/customized/local_files/{st.session_state['domain']}/{concept}.txt")
                                save_path.parent.mkdir(parents=True, exist_ok=True)
                                with open(save_path, 'w', encoding='utf-8') as f:
                                    f.write(synthetic_content)
                                st.session_state['generated_synthetic_files'].add(concept)
                                st.success(f"Synthetic file for concept '{concept}' created successfully.")
                else:
                    st.warning("Please configure a model to generate synthetic files.")

        # User input: Prompt Method
        prompt_method = st.radio("Select the prompt method:", ('split_sentences', 'questions'), index=0)

        # User input: Max Benchmark Length
        max_benchmark_length = st.slider("Select the maximum number of prompts per concept:", 1, 199, 10)

        # User input: Branching
        branching = st.radio("Enable branching:", ('Yes', 'No'), index=1)
        branching_enabled = branching == 'Yes'

        # User input: Replacer (only if branching is enabled)
        replacer = []
        replacement = {}
        if branching_enabled:
            replacer_text = st.text_area("Enter the replacer list (list of strings, separated by commas):",
                                         placeholder="e.g., Company A, Company B")
            replacer = clean_spaces(replacer_text.split(','))
            replacement = create_replacement_dict(st.session_state['concept_list'], replacer)

    # Configuration
    concept_specified_config = {
        x: {'keyword_finder': {'manual_keywords': [x]}} for x in st.session_state['concept_list']
    }
    concept_configuration = {
        'keyword_finder': {
            'require': st.session_state['keyword_finder_requirement'],
            'keyword_number': 1,
        },
        'source_finder': {
            'require': st.session_state['source_finder_requirement'],
            'scrap_number': 10,
            'method': scraper_method,
        },
        'scraper': {
            'require': True,
            'method': scraper_method,
        },
        'prompt_maker': {
            'method': prompt_method,
            'generation_function': model.invoke if model else None,
            'max_benchmark_length': max_benchmark_length,
        },
    }
    domain_configuration = {
        'categories': st.session_state['concept_list'],
        'branching': branching_enabled,
        'branching_config': {
            'generation_function': model.invoke if model else None,
            'keyword_reference': st.session_state['concept_list'],
            'replacement_descriptor_require': False,
            'replacement_description': replacement,
            'branching_pairs': 'not all',
            'direction': 'not both',
        },
        'shared_config': concept_configuration,
        'category_specified_config': concept_specified_config
    }
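
    # The dictionaries above drive the whole run: concept_configuration holds the per-concept
    # pipeline settings (keyword_finder, source_finder, scraper, prompt_maker), and
    # domain_configuration wraps it as 'shared_config' together with the concept list, the
    # branching settings, and the per-concept manual-keyword overrides in 'category_specified_config'.
    # It is passed unchanged to Pipeline.domain_benchmark_building() below.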

    # Renew Source Finder Button
    if st.button('Renew Source info'):
        renew_source_finder(st.session_state['domain'], st.session_state['concept_list'])

    # Save the original stdout to print to the terminal if needed later
    original_stdout = sys.stdout

    # Define StreamToText to capture and display logs in real time within Streamlit only
    class StreamToText:
        def __init__(self):
            self.output = StringIO()

        def write(self, message):
            if message.strip():  # Avoid adding empty messages
                # Only append to the Streamlit display, not the terminal
                st.session_state.log_messages.append(message.strip())
                log_placeholder.text("\n".join(st.session_state.log_messages))  # Flush updated logs

        def flush(self):
            pass  # Required for compatibility with sys.stdout

    # Initialize session state for log messages
    if 'log_messages' not in st.session_state:
        st.session_state.log_messages = []

    # Replace sys.stdout with our custom StreamToText instance
    stream_to_text = StreamToText()
    sys.stdout = stream_to_text

    # Placeholder for displaying logs within a collapsible expander
    with st.expander("Show Logs", expanded=False):
        log_placeholder = st.empty()  # Placeholder for dynamic log display

    # Define the Create Benchmark button
    if st.button("Create a Benchmark"):
        st.session_state.log_messages = []  # Clear previous logs
        with st.spinner("Creating benchmark..."):
            if st.session_state['check_source_finder']:
                # Check that the relevant source materials exist
                check_and_create_source_files(st.session_state['domain'], st.session_state['concept_list'])
            try:
                # Display a simulated progress bar alongside the log messages
                progress_bar = st.progress(0)
                for i in tqdm(range(1, 101)):
                    progress_bar.progress(i)
                    time.sleep(0.05)  # Short delay to simulate processing time

                # Run the benchmark creation function
                benchmark = Pipeline.domain_benchmark_building(st.session_state['domain'], domain_configuration)
                st.success("Benchmark creation completed!")
                st.dataframe(benchmark.data)
            except Exception as e:
                st.error(f"An error occurred during benchmark creation: {e}")