# ZIISA2 / model.py
from search import SemanticSearch, GoogleSearch, Document
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available
import yaml
import torch
import os # Added for environment variables
import nltk


def load_configs(config_file: str) -> dict:
    with open(config_file, "r") as f:
        configs = yaml.safe_load(f)
    return configs
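

# Expected layout of rag.configs.yml, inferred from the keys read below; the
# model id shown is only a placeholder assumption, not part of the original repo:
#
#   model:
#     generation_model: "google/gemma-2b-it"   # any causal chat model on the Hub
#     hf_token: "hf_..."                       # optional; HUGGINGFACE_TOKEN env var takes priority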


class RAGModel:
    def __init__(self, configs) -> None:
        self.configs = configs

        # 1. Get Hugging Face token (critical fix)
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN") or configs["model"].get("hf_token")
        if not self.hf_token:
            raise ValueError(
                "Missing Hugging Face token! Set either:\n"
                "1. HUGGINGFACE_TOKEN environment variable\n"
                "2. hf_token in config.yml"
            )

        # 2. Fix model URL key (typo correction)
        model_url = configs["model"]["generation_model"]  # Fixed "genration_model" -> "generation_model"

        # 3. Add authentication to model loading
        self.model = AutoModelForCausalLM.from_pretrained(
            model_url,
            token=self.hf_token,  # Added authentication
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            # Prefer FlashAttention 2 when installed, otherwise fall back to SDPA
            attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
            device_map="auto"  # Better device handling
        )
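
        # Optional: BitsAndBytesConfig (imported above) could enable 4-bit loading
        # to cut GPU memory use. A commented sketch under that assumption, not
        # part of the original loading path:
        #
        #   bnb_config = BitsAndBytesConfig(
        #       load_in_4bit=True,
        #       bnb_4bit_compute_dtype=torch.float16,
        #   )
        #   self.model = AutoModelForCausalLM.from_pretrained(
        #       model_url,
        #       token=self.hf_token,
        #       quantization_config=bnb_config,
        #       device_map="auto",
        #   )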

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_url,
            token=self.hf_token  # Added authentication
        )

    def create_prompt(self, query, topk_items: list[str]):
        context = "\n-".join(c for c in topk_items)

        # Improved prompt template
        base_prompt = f"""You are an AI search assistant. Use this context to answer:
Context: {context}
Question: {query}
Answer in Wikipedia-style format with these requirements:
- Detailed technical explanations
- Historical context where relevant
- Numerical data when available
- Markdown formatting for structure
"""
        dialog_template = [{"role": "user", "content": base_prompt}]

        # 4. Fix typo in apply_chat_template
        prompt = self.tokenizer.apply_chat_template(
            conversation=dialog_template,
            tokenize=False,
            add_generation_prompt=True  # Fixed "feneration" -> "generation"
        )
        return prompt

    def answer_query(self, query: str, topk_items: list[str]):
        prompt = self.create_prompt(query, topk_items)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Improved generation parameters
        output = self.model.generate(
            **inputs,
            temperature=0.7,
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1
        )

        # Better text cleanup: decode only the newly generated tokens so the
        # prompt is not echoed back in the answer
        generated_tokens = output[0][inputs["input_ids"].shape[1]:]
        text = self.tokenizer.decode(
            generated_tokens,
            skip_special_tokens=True,  # Better than manual replace
            clean_up_tokenization_spaces=True
        )
        return text


if __name__ == "__main__":
    # Test with authentication
    configs = load_configs("rag.configs.yml")

    # Add temporary token check
    if "HUGGINGFACE_TOKEN" not in os.environ:
        raise RuntimeError("Set HUGGINGFACE_TOKEN environment variable first!")
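
    # Example invocation (the token value below is just a placeholder):
    #   HUGGINGFACE_TOKEN=hf_xxx python model.py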

    rag = RAGModel(configs)
    print(rag.answer_query("What's the height of Burj Khalifa?", ["Burj Khalifa is 828 meters tall"]))