train-mbed / inference_chatgpt_simple.py
amos1088's picture
no
deae167
#!/usr/bin/env python3
"""
Simplified ChatGPT inference script for yes/no classification.
Model replies that do not normalize to 'yes' or 'no' are recorded as 'error'.
"""
import os
import time
from datetime import datetime
from multiprocessing.pool import ThreadPool

import pandas as pd
from openai import OpenAI

from app import format_prompt
# Initialize OpenAI client.
# The API key is read from the environment rather than hard-coded: a key
# committed to source is exposed to everyone with repository access, and any
# key that was previously committed here must be revoked. Set
# MICROSERVICES_FIVERR_OS_FIVERR_OS_BACKEND_CREDENTIALS_OPENAI_API_KEY (or the
# standard OPENAI_API_KEY) before running.
MICROSERVICES_FIVERR_OS_FIVERR_OS_BACKEND_CREDENTIALS_OPENAI_API_KEY = (
    os.environ.get("MICROSERVICES_FIVERR_OS_FIVERR_OS_BACKEND_CREDENTIALS_OPENAI_API_KEY")
    or os.environ.get("OPENAI_API_KEY")
)
# With api_key=None the OpenAI client falls back to OPENAI_API_KEY itself and
# raises at construction time if no key is available.
client = OpenAI(api_key=MICROSERVICES_FIVERR_OS_FIVERR_OS_BACKEND_CREDENTIALS_OPENAI_API_KEY)
def get_prediction(query, title, content, model="gpt-5-nano"):
"""Get yes/no prediction from ChatGPT"""
prompt = format_prompt(query, title, content)
try:
response = client.chat.completions.create(
model=model,
messages=[
{"role": "user", "content": prompt}
],
)
# Get prediction
prediction = response.choices[0].message.content.strip().lower()
# Ensure it's yes or no
if prediction not in ['yes', 'no']:
prediction = 'error'
print(prediction)
return prediction
except Exception as e:
print(f"API Error: {e}")
return 'error'
def main(csv_path="sampled_db.csv", model="gpt-5-nano", workers=100):
    """Classify every row of *csv_path* with ChatGPT and report the results.

    The CSV must have 'query_text', 'title', 'text' and 'label' columns.
    Prints a confusion matrix, the overall accuracy and the prediction counts.
    Writes the input rows plus a 'prediction' column to a timestamped
    ``chatgpt_results_*.csv`` file.

    Args:
        csv_path: Input CSV path.
        model: Chat model name forwarded to ``get_prediction``.
        workers: Maximum number of concurrent API request threads.
    """
    # Load CSV
    print(f"Loading {csv_path}...")
    df = pd.read_csv(csv_path)

    # One argument tuple per row, built column-wise: unlike iterrows() this
    # never upcasts row dtypes. map(str) matches the original str() casts, so
    # missing values still become the string 'nan'.
    jobs = list(zip(
        df['query_text'].map(str),
        df['title'].map(str),
        df['text'].map(str),
        [model] * len(df),
    ))
    # The 'with' block shuts down the worker threads once all requests are done.
    with ThreadPool(max(1, min(workers, len(jobs)))) as pool:
        df['prediction'] = pool.starmap(get_prediction, jobs)

    conf_matrix = pd.crosstab(
        index=df['label'],  # True labels
        columns=df['prediction'],  # Predicted labels
        rownames=['Actual'],
        colnames=['Predicted'],
    )
    # A row is correct when a positive label received 'yes' or a negative label
    # received 'no'. 'error' predictions and any other label count as incorrect.
    is_positive = df['label'].isin(['easy_positive', 'hard_positive'])
    is_negative = df['label'].isin(['easy_negative', 'hard_negative'])
    correct = (
        ((df['prediction'] == 'yes') & is_positive)
        | ((df['prediction'] == 'no') & is_negative)
    )
    print(conf_matrix)
    print(correct.mean())

    output = f"chatgpt_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    df.to_csv(output, index=False)
    print(f"\nSaved to: {output}")

    # Show summary
    print("\nResults:")
    print(df['prediction'].value_counts())
def make_sample_db(
    source_path="train_datasets_creation/full_train_dataset.csv",
    n_per_label=100,
    random_state=None,
    output_path=None,
):
    """Build a label-balanced sample of the training dataset and save it as CSV.

    Draws up to *n_per_label* rows for each distinct label in *source_path*.
    A label with fewer rows contributes all of them, and rows with a missing
    label are skipped. The source row number is kept in an 'index' column so
    sampled rows can be traced back.

    Args:
        source_path: CSV file with a 'label' column.
        n_per_label: Maximum number of rows sampled per label.
        random_state: Seed for reproducible sampling; None draws a fresh sample.
        output_path: Destination CSV. Defaults to
            ``sample_db_<YYYYmmdd_HHMMSS>.csv`` in the working directory.

    Returns:
        The path the sample was written to.
    """
    df = pd.read_csv(source_path)
    samples = []
    for label in df['label'].dropna().unique():
        group = df[df['label'] == label]
        # Cap at the group size: DataFrame.sample raises if n exceeds the rows.
        samples.append(group.sample(min(n_per_label, len(group)), random_state=random_state))
    # reset_index() moves the source row numbers into the 'index' column.
    sample = pd.concat(samples).reset_index() if samples else df.head(0).reset_index()
    if output_path is None:
        # Colon-free timestamp (isoformat()'s ':' is invalid in Windows file
        # names), using the same format as main()'s results file.
        output_path = f"sample_db_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    # The frame's own index is just 0..n-1 after reset_index(); writing it
    # would only add a redundant unnamed column.
    sample.to_csv(output_path, index=False)
    return output_path
if __name__ == "__main__":
    # NOTE(review): running this script only builds a new balanced sample via
    # make_sample_db(); the inference/evaluation in main() is never invoked.
    # Confirm this is intended. Also note that main() reads "sampled_db.csv",
    # not the sample_db_*.csv file that make_sample_db() writes.
    make_sample_db()