File size: 1,745 Bytes
785c4f7
11f2d5b
dd9d40e
5b3d26d
7a7ebad
e87c15e
453eba8
5b3d26d
6a1a7af
11f2d5b
453eba8
7d3dbed
5b3d26d
7a7ebad
5b3d26d
e87c15e
7a7ebad
5b3d26d
 
 
 
7a7ebad
5b3d26d
 
 
 
 
 
 
 
 
 
 
 
 
 
7a7ebad
5b3d26d
 
 
 
e87c15e
5b3d26d
e87c15e
 
 
 
5b3d26d
e87c15e
5b3d26d
e87c15e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from transformers import AutoTokenizer
from evo_model import EvoTransformerV22 
from search_utils import web_search
import openai
import os

# Build the Evo scorer and restore its weights from evo_hellaswag.pt onto the CPU.
model = EvoTransformerV22()
_evo_state = torch.load("evo_hellaswag.pt", map_location="cpu")
model.load_state_dict(_evo_state)
del _evo_state  # weights now live in `model`; drop the extra module-level reference
model.eval()  # eval mode for inference (affects dropout / batch-norm layers, if the model has any)

# Tokenizer used by get_evo_response to encode "query [SEP] option [CTX] context" strings.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# GPT setup: the API key comes from the environment (None if unset), never from source.
openai.api_key = os.environ.get("OPENAI_API_KEY")

def get_evo_response(query, options, user_context=""):
    """Score each candidate answer with the Evo model and return the best one.

    Web-search results for *query* (plus *user_context*, if non-empty) are
    joined with newlines into a context string. Each option is then encoded
    as ``"{query} [SEP] {option} [CTX] {context}"`` and scored independently
    by the module-level ``model``; the score is ``sigmoid(model output)``.

    Args:
        query: The question / prompt text.
        options: Non-empty sequence of candidate answer strings. (Earlier
            versions only handled exactly two; any count >= 1 now works and
            two options behave exactly as before.)
        user_context: Optional extra text appended after the search results.

    Returns:
        A 4-tuple ``(best_option, score_summary, best_score, context_str)``:
        the highest-scoring option (ties go to the earliest option), a
        human-readable ``"opt: 0.123 vs opt: 0.456 ..."`` string, that
        option's score, and the context string that was fed to the model.

    Raises:
        ValueError: If *options* is empty.
    """
    if not options:
        raise ValueError("get_evo_response requires at least one option")

    context_texts = list(web_search(query))
    if user_context:
        context_texts.append(user_context)
    context_str = "\n".join(context_texts)

    scores = []
    with torch.no_grad():  # inference only; one no_grad scope for the whole loop
        for opt in options:
            pair = f"{query} [SEP] {opt} [CTX] {context_str}"
            # Fixed 128-token window: anything beyond it (usually the tail of the
            # context) is truncated away before scoring.
            encoded = tokenizer(
                pair,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=128,
            )
            output = model(encoded["input_ids"])
            scores.append(torch.sigmoid(output).item())

    # argmax; max() returns the first maximal index, so ties favour the earlier
    # option, matching the original two-option `scores[1] > scores[0]` rule.
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    summary = " vs ".join(f"{opt}: {score:.3f}" for opt, score in zip(options, scores))
    return (
        options[best_idx],
        summary,
        scores[best_idx],
        context_str,
    )

def get_gpt_response(query, user_context="", *, gpt_model="gpt-3.5-turbo", temperature=0.7):
    """Ask an OpenAI chat model to answer *query*, optionally with extra context.

    Args:
        query: The user's question, sent as a single ``user`` message.
        user_context: Optional free text appended to the message under a
            ``Context:`` heading.
        gpt_model: Chat model name passed to the API (default is the value that
            was previously hard-coded).
        temperature: Sampling temperature passed to the API (default 0.7, as
            previously hard-coded).

    Returns:
        The model's reply with surrounding whitespace stripped. On any failure
        (network/API error, missing key, empty reply) a human-readable string
        beginning with ``"⚠️ GPT error:"`` is returned instead of raising, so
        callers can display it directly.
    """
    context_block = f"\n\nContext:\n{user_context}" if user_context else ""
    try:
        response = openai.chat.completions.create(
            model=gpt_model,
            messages=[
                {"role": "user", "content": query + context_block}
            ],
            temperature=temperature,
        )
        content = response.choices[0].message.content
    except Exception as e:  # deliberate UI boundary: report, don't raise
        return f"⚠️ GPT error:\n\n{e}"

    # message.content is Optional in the SDK; previously None crashed on .strip()
    # and surfaced as a confusing AttributeError message.
    if content is None:
        return "⚠️ GPT error:\n\nModel returned no text content."
    return content.strip()