# NOTE(review): the three lines here ("Spaces:" / "Running" / "Running") were
# Hugging Face Spaces page chrome captured by a scrape, not source code.
# Preserved as this comment so the module remains valid Python.
"""Inference helpers comparing a local EvoTransformer classifier with GPT-3.5."""
import os

import torch
from openai import OpenAI
from transformers import BertTokenizer

from evo_model import EvoTransformerForClassification

# === Load EvoTransformer model from disk ===
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()  # inference mode: disables dropout / batch-norm updates

# === BERT tokenizer (vocabulary must match the model's pretraining) ===
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# === GPT-3.5 setup (OpenAI SDK >= 1.0 client style) ===
# API key is read from the environment; OpenAI(...) does not validate it here.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def query_gpt35(prompt: str) -> str:
    """Ask GPT-3.5-turbo for a short (<= 50 token) answer to *prompt*.

    Returns the reply text, or a "[GPT-3.5 Error] ..." string on any
    failure (network, auth, rate limit) so callers never see an exception.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=50,
            temperature=0.3,
        )
        return response.choices[0].message.content.strip()
    # Deliberately broad: this is a best-effort boundary and the caller
    # displays whatever string comes back rather than handling errors.
    except Exception as e:
        return f"[GPT-3.5 Error] {str(e)}"
def generate_response(goal: str, option1: str, option2: str) -> dict:
    """Score two candidate options against *goal* and collect two opinions.

    EvoTransformer scores each "goal [SEP] option" pair and picks the one
    with the higher class-1 logit; GPT-3.5 is also asked independently.

    Returns:
        dict with keys:
          - "evo_suggestion": the option text EvoTransformer prefers
          - "gpt_suggestion": free-text answer from GPT-3.5 (or an error string)
    """
    text1 = goal + " [SEP] " + option1
    text2 = goal + " [SEP] " + option2

    # Tokenize both pairs to fixed-length 128-token tensors.
    enc1 = tokenizer(text1, return_tensors="pt", padding="max_length",
                     truncation=True, max_length=128)
    enc2 = tokenizer(text2, return_tensors="pt", padding="max_length",
                     truncation=True, max_length=128)

    # The custom model does not accept token_type_ids; drop them if present.
    enc1.pop("token_type_ids", None)
    enc2.pop("token_type_ids", None)

    # Forward passes without gradient tracking (inference only).
    with torch.no_grad():
        out1 = model(**enc1)
        out2 = model(**enc2)

    # Compatibility: the model may return (loss, logits) or bare logits.
    logits1 = out1[1] if isinstance(out1, tuple) else out1
    logits2 = out2[1] if isinstance(out2, tuple) else out2

    # Score = logit for class 1 (assumes class 1 means "better" -- TODO confirm
    # against the training labels).
    score1 = logits1[0][1].item()
    score2 = logits2[0][1].item()
    evo_result = option1 if score1 > score2 else option2

    # Second opinion from GPT-3.5. NOTE: this is always queried, not a
    # fallback -- both suggestions are returned side by side.
    prompt = (
        f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\n"
        "Which is better and why?"
    )
    gpt_result = query_gpt35(prompt)

    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": gpt_result,
    }