# Author: FelixPhilip
# Commit message: Oracle weight assigning update
# Commit: ba26d2b
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class SmolLM:
    """Thin wrapper around a Hugging Face causal LM used as the "Oracle".

    Loading failures are not raised: they are logged and the instance is
    marked unavailable. ``predict`` then returns the fallback string "0.5"
    (the log message calls it the default weight).

    Attributes:
        available: ``False`` if loading the tokenizer or model failed.
        device: ``"cuda"`` when a CUDA device is visible, else ``"cpu"``.
        tokenizer: The loaded tokenizer (only set if loading got that far).
        model: The loaded causal LM, moved to ``device`` (only set on success).
    """

    def __init__(self, model_path="HuggingFaceTB/SmolLM2-1.7B-Instruct"):
        """Load the tokenizer and model from ``model_path``.

        Args:
            model_path: Hub model id or local directory passed to
                ``from_pretrained``.
        """
        self.available = True
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            print(f"[INFO] Loading Oracle tokenizer from {model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            print(f"[INFO] Loading Oracle from {model_path} on {self.device}")
            self.model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
            print("[INFO] Oracle loaded successfully")
        except Exception as e:
            # Best-effort: a missing/broken model degrades to the fallback
            # answer in predict() instead of crashing the caller.
            print(f"[ERROR] Failed to load model '{model_path}': {e}")
            self.available = False

    def predict(self, prompt, max_length=512, max_new_tokens=150):
        """Generate a reply to ``prompt`` as a single user chat turn.

        Args:
            prompt: The user message text.
            max_length: Currently unused; kept so existing callers that pass
                it keep working.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded text of the newly generated tokens only (the prompt
            is not included), stripped of surrounding whitespace. Returns
            "0.5" if the model is unavailable or generation raises.
        """
        if not self.available:
            print("[WARN] Oracle unavailable, returning default weight 0.5")
            return "0.5"
        try:
            messages = [{"role": "user", "content": prompt}]
            # add_generation_prompt=True appends the assistant-turn header so
            # the model answers instead of continuing/starting a user turn.
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(self.device)
            # Single unpadded sequence: every position is a real token.
            attention_mask = torch.ones_like(input_ids)
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )
            # generate() returns prompt + continuation for decoder-only
            # models; slice off the prompt so numbers in it can't be mistaken
            # for the answer.
            new_tokens = outputs[0][input_ids.shape[-1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            print(f"[INFO] Generated response: {response[:100]}...", flush=True)
            return response
        except Exception as e:
            print(f"[ERROR] Oracle has failed: {e}")
            return "0.5"