# EvoTransformer-v2.1 / inference.py
# Author: HemanM — "Update inference.py" (commit 1c3ac0d, verified)
import os
import torch
from transformers import BertTokenizer
from evo_model import EvoTransformerForClassification
from openai import OpenAI
# === Load EvoTransformer model from disk ===
# Loads the fine-tuned classifier from the local "trained_model" directory
# at import time; eval() disables dropout/batch-norm updates for inference.
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()
# === BERT Tokenizer ===
# Stock bert-base-uncased vocab — presumably what the model was trained
# with; verify against the training pipeline.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# === GPT-3.5 Setup ===
# Requires OPENAI_API_KEY in the environment; None here surfaces as an
# auth error on the first API call, not at import.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) # New SDK format
def query_gpt35(prompt):
    """Ask GPT-3.5-turbo for a short answer to *prompt*.

    Returns the stripped completion text, or a bracketed error string
    if the API call (or response parsing) fails for any reason — the
    caller always receives a plain string, never an exception.
    """
    request = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 50,
        "temperature": 0.3,
    }
    try:
        completion = client.chat.completions.create(**request)
        return completion.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort: report the failure inline rather than raising.
        return f"[GPT-3.5 Error] {str(e)}"
def generate_response(goal, option1, option2):
    """Pick the better of two options for *goal*, two ways.

    EvoTransformer scores each "goal [SEP] option" pair and the higher
    class-1 logit wins (ties go to option2); GPT-3.5 is asked the same
    question in free text. Returns a dict with keys "evo_suggestion"
    (the winning option string) and "gpt_suggestion" (GPT's reply).
    """

    def _class1_logit(option):
        # Encode one goal/option pair to fixed length 128.
        encoded = tokenizer(
            goal + " [SEP] " + option,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128,
        )
        # The model does not take segment ids — drop them if emitted.
        encoded.pop("token_type_ids", None)
        with torch.no_grad():
            output = model(**encoded)
        # Model may return (loss, logits) or bare logits.
        logits = output[1] if isinstance(output, tuple) else output
        # Class 1 assumed to mean "better" — confirm against training labels.
        return logits[0][1].item()

    score1 = _class1_logit(option1)
    score2 = _class1_logit(option2)
    evo_result = option1 if score1 > score2 else option2

    # Second opinion from GPT-3.5.
    prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better and why?"
    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": query_gpt35(prompt),
    }