Spaces:
Sleeping
Sleeping
File size: 1,848 Bytes
d3d00df 785c4f7 453eba8 11f2d5b 21afb35 11f2d5b 453eba8 11f2d5b 453eba8 11f2d5b 453eba8 7d3dbed 11f2d5b 21afb35 11f2d5b 453eba8 11f2d5b 7d3dbed 11f2d5b 7d3dbed 11f2d5b 453eba8 11f2d5b 7d3dbed 11f2d5b 453eba8 11f2d5b 453eba8 11f2d5b 453eba8 11f2d5b 21afb35 11f2d5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import os
import torch
from evo_model import EvoTransformer
from transformers import AutoTokenizer
from rag_utils import extract_text_from_file
from search_utils import web_search
# Load the Evo model and its tokenizer once at import time; the functions
# below read these module-level `tokenizer` / `model` objects on every call.
# NOTE(review): the "bert-base-uncased" vocabulary is presumably what Evo was
# trained with — TODO confirm against the training script.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = EvoTransformer()
# Restore trained weights from a local checkpoint, mapped onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
# Put the model in inference mode (e.g. dropout disabled) before serving.
model.eval()
def get_evo_response(query, context=None, enable_search=True):
    """Let the Evo model pick between "Option 1" and "Option 2" for *query*.

    Args:
        query: The question text; each candidate fed to the model is
            ``f"{query} Option N"``.
        context: Optional caller-supplied context, echoed in the explanation.
        enable_search: When true, ``web_search(query)`` is called and any
            returned snippets are newline-joined into the explanation.

    Returns:
        A ``(suggestion, reasoning)`` tuple. ``suggestion`` is
        ``"Option <k>"`` for the highest-probability candidate. ``reasoning``
        is a Markdown string with that choice, its softmax confidence, and the
        combined context text.

    NOTE(review): the context and search snippets are only shown in
    ``reasoning``; they are never passed to the model, which scores just the
    two ``query``-based candidates.
    """
    # Gather optional web snippets (only used for the explanation text).
    snippets = web_search(query) if enable_search else None
    search_snippets = "\n".join(snippets) if snippets else ""
    full_context = f"{context or ''}\n\n{search_snippets}".strip()

    # Score both candidates in a single batch.
    candidates = [f"{query} Option {number}" for number in (1, 2)]
    batch = tokenizer(candidates, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        # Assumes the model emits one score per candidate, shape (2, 1) —
        # TODO confirm against EvoTransformer's forward().
        scores = model(batch["input_ids"]).squeeze(-1)

    # Normalise scores across the candidates and take the most likely one.
    confidence = torch.softmax(scores, dim=0)
    pick = torch.argmax(confidence).item()
    suggestion = f"Option {pick + 1}"

    reasoning = (
        f"Evo suggests: **{suggestion}** (Confidence: {confidence[pick]:.2f})\n\n"
        f"Context used:\n{full_context}"
    )
    return suggestion, reasoning
def get_gpt_response(query, context=None):
    """Ask gpt-3.5-turbo for an answer to *query* and return its text.

    The API key is read from the ``OPENAI_API_KEY`` environment variable,
    defaulting to an empty string. It is assigned to ``openai.api_key`` as
    before, and is also passed explicitly to the client on openai>=1.0.

    Args:
        query: The user's question.
        context: Optional context text; sent as the literal string "None"
            when falsy.

    Returns:
        The stripped assistant reply (``""`` if the API returned no text).
        This is a best-effort call: any failure, including a missing
        ``openai`` package, is returned as ``"⚠️ GPT error: <message>"``
        instead of being raised.
    """
    context = context or "None"
    messages = [
        {"role": "system", "content": "You are a helpful expert advisor."},
        {"role": "user", "content": f"Context: {context}\n\nQuestion: {query}"},
    ]
    try:
        # Imported lazily and inside the try so a missing package is
        # reported like any other GPT failure rather than crashing the caller.
        import openai

        api_key = os.getenv("OPENAI_API_KEY", "")
        openai.api_key = api_key
        if hasattr(openai, "OpenAI"):
            # openai>=1.0 removed `openai.ChatCompletion`; use the client API.
            client = openai.OpenAI(api_key=api_key)
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=messages,
                max_tokens=250,
            )
            content = response.choices[0].message.content
        else:
            # Legacy openai<1.0 module-level API.
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
                max_tokens=250,
            )
            content = response["choices"][0]["message"]["content"]
        # `content` may be None (e.g. an empty completion); normalise to "".
        return (content or "").strip()
    except Exception as e:
        return f"⚠️ GPT error: {str(e)}"
|