Spaces:
Paused
Paused
#!/usr/bin/env python3 | |
""" | |
Simplified ChatGPT inference script for yes/no classification.
Model replies are normalized to "yes", "no", or "error" for consistent results.
""" | |
import os
import time
from datetime import datetime
from multiprocessing.pool import ThreadPool

import pandas as pd
from openai import OpenAI

from app import format_prompt
# Initialize OpenAI client.
# SECURITY: an API key used to be hard-coded here and was published with this
# file. Treat that key as compromised and revoke/rotate it. The key is now
# read from the environment (the variable of the same name, falling back to
# OPENAI_API_KEY) so that no secret lives in the source.
MICROSERVICES_FIVERR_OS_FIVERR_OS_BACKEND_CREDENTIALS_OPENAI_API_KEY = os.environ.get(
    "MICROSERVICES_FIVERR_OS_FIVERR_OS_BACKEND_CREDENTIALS_OPENAI_API_KEY"
) or os.environ.get("OPENAI_API_KEY")

# `client` is None when no key is configured, so the module can still be
# imported for helpers that never call the API (e.g. make_sample_db). In that
# case get_prediction's API call fails inside its try block and it returns
# 'error'.
client = (
    OpenAI(api_key=MICROSERVICES_FIVERR_OS_FIVERR_OS_BACKEND_CREDENTIALS_OPENAI_API_KEY)
    if MICROSERVICES_FIVERR_OS_FIVERR_OS_BACKEND_CREDENTIALS_OPENAI_API_KEY
    else None
)
def get_prediction(query, title, content, model="gpt-5-nano"):
    """Get a yes/no prediction from ChatGPT for one (query, title, content) item.

    The prompt is built with ``format_prompt`` and sent as a single user
    message to the chat completions API.

    Args:
        query: Query text passed to ``format_prompt``.
        title: Title text passed to ``format_prompt``.
        content: Body text passed to ``format_prompt``.
        model: Name of the chat model to call.

    Returns:
        ``'yes'`` or ``'no'`` when the reply is one of those words (ignoring
        case, surrounding whitespace, quotes and punctuation such as a
        trailing period); ``'error'`` for any other reply, an empty reply, or
        any exception raised by the API call.
    """
    prompt = format_prompt(query, title, content)
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "user", "content": prompt}
            ],
        )
        # The message content can be None (no text returned); treat it as an
        # empty reply instead of relying on an AttributeError being swallowed.
        raw_reply = response.choices[0].message.content or ""
    except Exception as e:
        # Best-effort by design: one failed call must not abort a whole batch.
        print(f"API Error: {e}")
        return 'error'

    # Normalise before validating so replies like "Yes." or "'no'" are not
    # misclassified as errors.
    prediction = raw_reply.strip().lower().strip(" \t\r\n.!,;:\"'`")
    # Ensure it's yes or no
    if prediction not in ('yes', 'no'):
        prediction = 'error'
    print(prediction)
    return prediction
def main(csv_path="sampled_db.csv", max_workers=100):
    """Run yes/no inference over every row of a CSV and report the results.

    Reads ``csv_path`` (the ``query_text``, ``title``, ``text`` and ``label``
    columns are used), calls ``get_prediction`` for each row on a thread pool,
    prints a confusion matrix and the accuracy, and saves the frame with an
    added ``prediction`` column to a timestamped ``chatgpt_results_*.csv``.

    Args:
        csv_path: Path of the input CSV.
        max_workers: Upper bound on the number of concurrent prediction calls.
    """
    # Load CSV
    print(f"Loading {csv_path}...")
    df = pd.read_csv(csv_path)

    # One (query, title, content) argument tuple per row. str() keeps the
    # existing behaviour of passing missing values as the text 'nan'.
    jobs = [
        (str(query), str(title), str(text))
        for query, title, text in zip(df['query_text'], df['title'], df['text'])
    ]

    # Never start more workers than there are rows (and at least one, since
    # ThreadPool rejects a size of 0). The context manager shuts the worker
    # threads down once the blocking starmap call has returned.
    pool_size = max(1, min(max_workers, len(jobs)))
    with ThreadPool(pool_size) as pool:
        df['prediction'] = pool.starmap(get_prediction, jobs)

    conf_matrix = pd.crosstab(
        index=df['label'],          # True labels
        columns=df['prediction'],   # Predicted labels
        rownames=['Actual'],
        colnames=['Predicted'],
    )

    # 'yes' is correct for the positive labels and 'no' for the negative ones;
    # any other prediction (including 'error') counts as incorrect.
    is_positive = df['label'].isin(['easy_positive', 'hard_positive'])
    is_negative = df['label'].isin(['easy_negative', 'hard_negative'])
    is_correct = ((df['prediction'] == 'yes') & is_positive) | (
        (df['prediction'] == 'no') & is_negative
    )
    print(conf_matrix)
    print(is_correct.mean())

    output = f"chatgpt_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    df.to_csv(output, index=False)
    print(f"\nSaved to: {output}")

    # Show summary
    print("\nResults:")
    print(df['prediction'].value_counts())
def make_sample_db(
    source_csv="train_datasets_creation/full_train_dataset.csv",
    samples_per_label=100,
):
    """Write a label-balanced random sample of the training set to a new CSV.

    Up to ``samples_per_label`` rows are drawn at random for every distinct
    value of the ``label`` column (all rows when a label has fewer), and the
    result is saved as ``sample_db_<YYYYmmdd_HHMMSS>.csv`` in the current
    working directory.

    Args:
        source_csv: Path of the full dataset to sample from.
        samples_per_label: Maximum number of rows to keep per label.
    """
    df = pd.read_csv(source_csv)

    # Cap the request at the group size: DataFrame.sample raises ValueError
    # when asked for more rows than exist (sampling without replacement).
    # sort=False keeps labels in order of first appearance; rows whose label
    # is missing are skipped by groupby.
    parts = [
        group.sample(n=min(samples_per_label, len(group)))
        for _, group in df.groupby('label', sort=False)
    ]

    # reset_index() keeps each row's position in the source file as an
    # 'index' column; to_csv additionally writes the new 0..n-1 index.
    sampled = pd.concat(parts).reset_index()

    # Same timestamp format as the results file written by main(); unlike
    # isoformat() it contains no ':' and is therefore a valid file name on
    # every platform.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    sampled.to_csv(f"sample_db_{stamp}.csv")
# Script entry point: running this file builds a fresh sample CSV via
# make_sample_db(); it does not invoke main().
if __name__ == "__main__":
    make_sample_db()