| from typing import Dict, List, Any |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| from peft import PeftModel |
| import re |
| import os |
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a LoRA-finetuned
    Llama-3.1-8B-Instruct math solver.

    Loads the base model (4-bit NF4 quantized on GPU, fp32 on CPU), attaches
    the PEFT adapter stored at ``path``, and answers math questions in the
    GSM8K-style ``#### <answer>`` format.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer for the inference endpoint.

        Args:
            path: The path to the model/adapter directory (provided by HF
                Inference Endpoints).
        """
        self.base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Llama is a gated repo; an access token may be required to download.
        hf_token = os.environ.get("HF_TOKEN", None)

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        # Llama ships without a pad token; generate() needs one defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if torch.cuda.is_available():
            # 4-bit NF4 quantization keeps the 8B model within a single GPU.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                token=hf_token,
            )
        else:
            # FIX: use float32 on CPU. float16 has poor/missing CPU kernel
            # support and can raise at generation time.
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                token=hf_token,
            )

        # Attach the LoRA adapter shipped with the endpoint repository.
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

        # Defaults; each request's "parameters" dict may override any of these.
        self.generation_config = {
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "max_new_tokens": 1000,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

    def format_math_prompt(self, question: str) -> str:
        """Format a math question with proper instructions.

        Builds a Llama-3 chat-template prompt (system + user turns, ending at
        the assistant header so the model continues as the assistant).
        """
        instructions = """Please solve this math problem step by step, following these rules:
1) Start by noting all the facts from the problem.
2) Show your work by performing inner calculations inside double angle brackets, like <<calculation=result>>.
3) You MUST write the final answer on a new line with a #### prefix.
Note - each answer must be of length <= 400."""

        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{instructions}<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"
        return prompt

    @staticmethod
    def _parse_number(text: str):
        """Parse *text* (commas already stripped) as int or float.

        Returns None when the text is not a valid number.
        """
        try:
            return float(text) if '.' in text else int(text)
        except ValueError:
            return None

    def extract_answer(self, response: str) -> Any:
        """Extract the final numeric answer from the model response.

        Prefers the '#### <number>' marker; falls back to the last number in
        the text. Returns an int or float, or None if nothing numeric is
        found.
        """
        # FIX: require at least one digit so stray '-', '.', or ',' can never
        # match (the old pattern [-\d,\.]+ matched pure punctuation).
        number_pattern = r'-?[\d,]*\d(?:\.\d+)?'

        answer_match = re.search(r'####\s*(' + number_pattern + r')', response)
        if answer_match:
            answer_str = answer_match.group(1).replace(',', '')
            parsed = self._parse_number(answer_str)
            # Preserve the original fallback of returning the raw string if
            # parsing somehow fails.
            return parsed if parsed is not None else answer_str

        # Fallback: take the last number appearing anywhere in the response.
        numbers = re.findall(number_pattern, response)
        if numbers:
            parsed = self._parse_number(numbers[-1].replace(',', ''))
            if parsed is not None:
                return parsed

        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: A dictionary containing the input data
                - inputs: str or List[str] - The math questions to solve
                - parameters (optional): Dict with generation parameters

        Returns:
            List of dictionaries containing the results, one per question:
            question, full_response, answer (int/float/None), and
            formatted_answer ('#### <answer>' or 'No answer found').
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Normalize a single question into a batch of one.
        questions = [inputs] if isinstance(inputs, str) else inputs

        # Copy so per-request overrides never mutate the shared defaults.
        gen_config = self.generation_config.copy()
        gen_config.update(parameters)

        results = []
        for question in questions:
            prompt = self.format_math_prompt(question)

            # NOTE: max_length=512 silently truncates very long questions.
            model_inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **model_inputs,
                    **gen_config,
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            input_length = model_inputs['input_ids'].shape[1]
            generated_tokens = outputs[0][input_length:]
            assistant_response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            extracted_answer = self.extract_answer(assistant_response)

            results.append({
                "question": question,
                "full_response": assistant_response,
                "answer": extracted_answer,
                "formatted_answer": f"#### {extracted_answer}" if extracted_answer is not None else "No answer found",
            })

        return results