from search import SemanticSearch, GoogleSearch, Document
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available
import yaml
import torch
import os  # used to read the token from the environment
import nltk


def load_configs(config_file: str) -> dict:
    with open(config_file, "r") as f:
        configs = yaml.safe_load(f)
    return configs
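
# A minimal sketch of the rag.configs.yml layout this code expects, inferred from
# the keys accessed in RAGModel.__init__ below; the exact file contents are an
# assumption, and the model id shown is hypothetical:
#
#   model:
#     generation_model: "google/gemma-2b-it"   # hypothetical model id
#     hf_token: "hf_..."                       # optional; the env var takes priority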


class RAGModel:
    def __init__(self, configs) -> None:
        self.configs = configs

        # 1. Resolve the Hugging Face token (required for gated models);
        #    the environment variable takes priority over the config file
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN") or configs["model"].get("hf_token")
        if not self.hf_token:
            raise ValueError(
                "Missing Hugging Face token! Set either:\n"
                "1. HUGGINGFACE_TOKEN environment variable\n"
                "2. hf_token in rag.configs.yml"
            )

        # 2. Read the generation model id (key is "generation_model", not the
        #    misspelled "genration_model")
        model_url = configs["model"]["generation_model"]

        # 3. Load the model with authentication; prefer FlashAttention 2 when
        #    the package is available, otherwise fall back to SDPA
        attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_url,
            token=self.hf_token,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            attn_implementation=attn_implementation,
            device_map="auto",  # let accelerate place layers across available devices
        )
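
        # BitsAndBytesConfig is imported above but never wired in; a minimal
        # 4-bit quantization sketch (an assumption, not part of the original
        # setup) would look like:
        #
        #   quant_config = BitsAndBytesConfig(
        #       load_in_4bit=True,
        #       bnb_4bit_compute_dtype=torch.float16,
        #   )
        #   # then pass quantization_config=quant_config to from_pretrained above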

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_url,
            token=self.hf_token,
        )

    def create_prompt(self, query, topk_items: list[str]):
        # Bullet every retrieved passage (the old "\n-".join left the first
        # item without a bullet)
        context = "\n".join(f"- {c}" for c in topk_items)

        base_prompt = f"""You are an AI search assistant. Use this context to answer:
Context: {context}
Question: {query}
Answer in Wikipedia-style format with these requirements:
- Detailed technical explanations
- Historical context where relevant
- Numerical data when available
- Markdown formatting for structure
"""
        dialog_template = [{"role": "user", "content": base_prompt}]

        # 4. Wrap the prompt in the model's chat format and append the
        #    assistant turn marker so generation starts at the answer
        prompt = self.tokenizer.apply_chat_template(
            conversation=dialog_template,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt
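
    # For a Gemma-style chat template (an assumption; the exact wrapping depends
    # on the model's tokenizer config), the returned prompt looks roughly like:
    #
    #   <start_of_turn>user
    #   You are an AI search assistant. Use this context to answer: ...
    #   <end_of_turn>
    #   <start_of_turn>model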

    def answer_query(self, query: str, topk_items: list[str]):
        prompt = self.create_prompt(query, topk_items)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Sampling parameters: moderate temperature with nucleus sampling,
        # plus a light repetition penalty
        output = self.model.generate(
            **inputs,
            temperature=0.7,
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
        )

        # Decode only the newly generated tokens so the prompt is not echoed
        # back in the answer
        new_tokens = output[0][inputs["input_ids"].shape[-1]:]
        text = self.tokenizer.decode(
            new_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return text


if __name__ == "__main__":
    # Smoke test with authentication
    configs = load_configs("rag.configs.yml")

    # Temporary token check: fail fast before loading any weights
    if "HUGGINGFACE_TOKEN" not in os.environ:
        raise RuntimeError("Set HUGGINGFACE_TOKEN environment variable first!")

    rag = RAGModel(configs)
    print(rag.answer_query("What's the height of Burj Khalifa?", ["Burj Khalifa is 828 meters tall"]))
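
# To run (assuming a POSIX shell and that this file is saved as rag_model.py,
# a hypothetical filename):
#
#   export HUGGINGFACE_TOKEN=hf_...
#   python rag_model.py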