File size: 2,044 Bytes
c3ba93a
 
a457e2e
b530936
1c3ac0d
4e96bf5
c3ba93a
b530936
4e96bf5
 
c3ba93a
a457e2e
 
c3ba93a
1c3ac0d
854864a
a457e2e
 
1c3ac0d
a457e2e
 
 
 
 
3bfaa31
a457e2e
c3ba93a
854864a
a457e2e
c3ba93a
 
e33deda
c3ba93a
 
 
e33deda
c3ba93a
 
 
981d63b
c3ba93a
4e96bf5
c3ba93a
 
 
 
 
 
e33deda
1c3ac0d
c3ba93a
 
e33deda
 
a457e2e
c3ba93a
 
a457e2e
 
 
 
 
 
1c3ac0d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import torch
from transformers import BertTokenizer
from evo_model import EvoTransformerForClassification
from openai import OpenAI

# === Load EvoTransformer model from disk ===
# Loaded once at import time from the local "trained_model" directory; the
# functions below read this module-level `model` global.
model = EvoTransformerForClassification.from_pretrained("trained_model")
# Put the model in inference mode (e.g. disables dropout) before scoring.
model.eval()

# === BERT Tokenizer ===
# NOTE(review): presumably this must match the tokenizer used when
# "trained_model" was trained — confirm.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# === GPT-3.5 Setup ===
# os.getenv returns None when OPENAI_API_KEY is unset.
# NOTE(review): the OpenAI client may raise at construction when no API key is
# available, which would make importing this module fail — confirm/handle.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))  # New SDK format

def query_gpt35(prompt, *, model_name="gpt-3.5-turbo", max_tokens=50, temperature=0.3):
    """Send a single-turn prompt to an OpenAI chat model and return its reply.

    Any ``Exception`` raised by the request (network, auth, rate limit,
    malformed response) is converted into a string starting with
    ``"[GPT-3.5 Error]"``, so callers can display the result directly.

    Args:
        prompt: User message text.
        model_name: Chat model to query.
        max_tokens: Upper bound on reply length. The default of 50 tokens
            may cut off longer explanations.
        temperature: Sampling temperature.

    Returns:
        The stripped reply text, or an ``"[GPT-3.5 Error] ..."`` string.
    """
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        content = response.choices[0].message.content
    except Exception as e:  # UI boundary: report the failure instead of crashing the caller
        return f"[GPT-3.5 Error] {e}"
    # `message.content` is Optional in the SDK (e.g. a refusal or tool-call-only
    # reply). Report that explicitly instead of failing on None.strip().
    if content is None:
        return "[GPT-3.5 Error] empty response"
    return content.strip()

def _extract_logits(output):
    """Return the logits tensor from a classification model's forward output.

    Accepts an object exposing a ``.logits`` attribute (e.g. a Hugging Face
    ``ModelOutput``), a ``(loss, logits, ...)`` tuple, a one-element
    ``(logits,)`` tuple, or a bare logits tensor.
    """
    logits = getattr(output, "logits", None)
    if logits is not None:
        return logits
    if isinstance(output, tuple):
        # (loss, logits, ...) carries logits second; a 1-tuple holds only logits.
        return output[1] if len(output) > 1 else output[0]
    return output


def _evo_score(goal, option, max_length=128):
    """Score one (goal, option) pair with the module-level EvoTransformer model.

    Returns the class-1 logit as a Python float; a higher value is treated as
    "better".
    """
    text = goal + " [SEP] " + option
    enc = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    # Drop token_type_ids if the tokenizer produced them.
    # NOTE(review): presumably EvoTransformer's forward() does not accept them — confirm.
    enc.pop("token_type_ids", None)
    with torch.no_grad():
        output = model(**enc)
    logits = _extract_logits(output)
    # Assumes logits is (batch=1, num_classes>=2) and class 1 means "better" — TODO confirm.
    return logits[0][1].item()


def generate_response(goal, option1, option2, *, max_length=128):
    """Compare two options for a goal using EvoTransformer and GPT-3.5.

    Each option is paired with the goal as ``"goal [SEP] option"`` and scored
    independently by the EvoTransformer model. GPT-3.5 is always queried as
    well, for a free-text second opinion.

    Args:
        goal: The goal/question text.
        option1: First candidate option.
        option2: Second candidate option.
        max_length: Token length each input is padded/truncated to.

    Returns:
        A dict with two keys:
        ``"evo_suggestion"``: the option text EvoTransformer scored higher
        (option2 on a tie).
        ``"gpt_suggestion"``: the reply from ``query_gpt35``.
    """
    score1 = _evo_score(goal, option1, max_length)
    score2 = _evo_score(goal, option2, max_length)

    # Strict ">" keeps the original tie-breaking: equal scores pick option2.
    evo_result = option1 if score1 > score2 else option2

    # GPT-3.5 second opinion (always queried, not only when Evo is uncertain).
    prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better and why?"
    gpt_result = query_gpt35(prompt)

    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": gpt_result,
    }