| """ |
| Custom handler for ConvSearch-R1 query rewriting on HuggingFace Inference Endpoints. |
| |
| Accepts conversation context + query and returns a rewritten query. |
| """ |
|
|
| import re |
| import torch |
| from typing import Any, Dict, List, Union |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
# Prompt template for ConvSearch-R1 query rewriting.
# The model is instructed to emit its reasoning inside <think>...</think>
# and the final decontextualized query inside <rewrite>...</rewrite>
# (the handler's `_extract_rewrite` parses the latter tag pair).
# Placeholders:
#   {context} - conversation history formatted as "Q1: ... / A1: ..." lines
#   {query}   - the current user question to rewrite
# NOTE: the trailing backslashes are string-literal line continuations and
# are intentional — they keep the instruction paragraph on a single line.
PROMPT_TEMPLATE = """Given a query and its context, you must first think about the reasoning process in the mind to decontextualize the query by resolving \
coreference and omission issues. Then, provide the user with a rewrite that retains its original meaning and is as informative as possible to help \
search engines retrieve relevant documents effectively. The reasoning process and rewrite should be enclosed within <think> </think> and <rewrite> </rewrite> \
tags, respectively, i.e., <think> reasoning process here </think>
<rewrite> rewrite here </rewrite>.

### Context Begin ###
{context}
### Context End ###

Query: {query}
Rewrite:"""
|
|
|
|
class EndpointHandler:
    """Handler for ConvSearch-R1 conversational query rewriting.

    Wraps a causal LM that, given a conversation history and the current
    query, emits ``<think>...</think>`` reasoning followed by a
    ``<rewrite>...</rewrite>`` block containing the decontextualized query.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from *path*; move the model to GPU if available.

        Args:
            path: Model repo id or local directory understood by
                ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"ConvSearch-R1 loaded on {self.device}")

    def _format_prompt(self, context: List[str], query: str) -> str:
        """Render the conversation *context* and current *query* into the prompt.

        *context* is assumed to alternate question/answer turns
        (Q1, A1, Q2, A2, ...); a trailing unanswered question is allowed.
        """
        ctx_lines = []
        for i in range(0, len(context), 2):
            turn = i // 2 + 1
            ctx_lines.append(f"Q{turn}: {context[i]}")
            if i + 1 < len(context):
                ctx_lines.append(f"A{turn}: {context[i + 1]}")
        return PROMPT_TEMPLATE.format(
            context="\n".join(ctx_lines),
            query=query,
        )

    def _extract_rewrite(self, output: str) -> str:
        """Return the text inside the first <rewrite>...</rewrite> pair.

        Falls back to the whole (stripped) output when the tags are absent,
        so callers always get a non-tagged string back.
        """
        match = re.search(r"<rewrite>(.*?)</rewrite>", output, re.DOTALL)
        if match:
            return match.group(1).strip()
        return output.strip()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Input format:
            {
                "inputs": [
                    {"context": ["Q1", "A1", "Q2", "A2"], "query": "current question"},
                    ...
                ],
                "parameters": {"temperature": 0.7, "max_new_tokens": 1024,
                               "max_input_length": 2048}
            }

        Or single input:
            {
                "inputs": {"context": [...], "query": "..."},
                "parameters": {...}
            }

        Returns:
            [{"rewrite": "rewritten query", "raw_output": "full model output"}, ...]
        """
        inputs = data.get("inputs", data)
        params = data.get("parameters", {})
        temperature = params.get("temperature", 0.7)
        max_new_tokens = params.get("max_new_tokens", 4096)
        # Input-side truncation limit is configurable (previously hard-coded 2048).
        max_input_length = params.get("max_input_length", 2048)

        if isinstance(inputs, dict):
            inputs = [inputs]

        # FIX: only pass sampling knobs when actually sampling. HF `generate`
        # warns (and recent versions error) when temperature=0.0 / top_p are
        # combined with do_sample=False.
        gen_kwargs: Dict[str, Any] = {"max_new_tokens": max_new_tokens}
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=1.0)
        else:
            gen_kwargs["do_sample"] = False
        # FIX: causal-LM tokenizers often lack a pad token; fall back to EOS
        # to avoid the per-call "pad_token_id not set" warning.
        if self.tokenizer.pad_token_id is None:
            gen_kwargs["pad_token_id"] = self.tokenizer.eos_token_id

        results = []
        for inp in inputs:
            context = inp.get("context", [])
            query = inp.get("query", "")

            prompt_text = self._format_prompt(context, query)
            messages = [{"role": "user", "content": prompt_text}]
            formatted = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            tokens = self.tokenizer(
                formatted,
                return_tensors="pt",
                truncation=True,
                max_length=max_input_length,
            ).to(self.device)

            with torch.no_grad():
                output_ids = self.model.generate(**tokens, **gen_kwargs)

            # Drop the echoed prompt tokens; decode only the newly
            # generated continuation.
            new_tokens = output_ids[0][tokens["input_ids"].shape[1]:]
            raw_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            results.append({
                "rewrite": self._extract_rewrite(raw_output),
                "raw_output": raw_output,
            })

        return results
|
|