Spaces:
Running
Running
File size: 2,044 Bytes
c3ba93a a457e2e b530936 1c3ac0d 4e96bf5 c3ba93a b530936 4e96bf5 c3ba93a a457e2e c3ba93a 1c3ac0d 854864a a457e2e 1c3ac0d a457e2e 3bfaa31 a457e2e c3ba93a 854864a a457e2e c3ba93a e33deda c3ba93a e33deda c3ba93a 981d63b c3ba93a 4e96bf5 c3ba93a e33deda 1c3ac0d c3ba93a e33deda a457e2e c3ba93a a457e2e 1c3ac0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import os
import torch
from transformers import BertTokenizer
from evo_model import EvoTransformerForClassification
from openai import OpenAI
# === Load EvoTransformer model from disk ===
# Weights are loaded from the local "trained_model" directory; eval() disables
# dropout/batch-norm updates since this module only performs inference.
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()
# === BERT Tokenizer ===
# Tokenizer only — the encoder itself is the EvoTransformer above.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# === GPT-3.5 Setup ===
# Requires OPENAI_API_KEY in the environment; OpenAI(...) is the v1.x SDK style.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))  # New SDK format
def query_gpt35(prompt):
    """Ask GPT-3.5-turbo a single-turn question and return its reply text.

    Any API or parsing failure is reported as a bracketed error string
    rather than raised, so callers never have to handle exceptions.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    try:
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=chat_messages,
            max_tokens=50,
            temperature=0.3,
        )
        reply = completion.choices[0].message.content
        return reply.strip()
    except Exception as err:
        return f"[GPT-3.5 Error] {str(err)}"
def _evo_score(goal, option):
    """Return EvoTransformer's class-1 logit ("better") for a goal/option pair.

    The pair is joined with " [SEP] " and encoded to a fixed length of 128
    tokens, matching the format the classifier was trained on.
    """
    encoded = tokenizer(
        goal + " [SEP] " + option,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    # BERT's tokenizer emits token_type_ids, which EvoTransformer's forward()
    # does not accept — drop them before unpacking kwargs into the model.
    encoded.pop("token_type_ids", None)
    with torch.no_grad():
        output = model(**encoded)
    # Compatibility: the model may return (loss, logits) or bare logits.
    logits = output[1] if isinstance(output, tuple) else output
    # Score = logit for class 1 (assumes class 1 = "better").
    return logits[0][1].item()


def generate_response(goal, option1, option2):
    """Pick the better of two options for *goal* with EvoTransformer, and
    independently ask GPT-3.5 for its opinion.

    Note: GPT-3.5 is always queried alongside the local model (a second
    opinion, not a fallback).

    Returns:
        dict with keys "evo_suggestion" (the option text whose class-1 logit
        is strictly higher; ties go to option2) and "gpt_suggestion" (GPT's
        free-text answer, or a bracketed error string on API failure).
    """
    score1 = _evo_score(goal, option1)
    score2 = _evo_score(goal, option2)
    evo_result = option1 if score1 > score2 else option2

    prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better and why?"
    gpt_result = query_gpt35(prompt)

    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": gpt_result
    }
|