|
|
|
"""Utilities for loading the ZeroSearch simulation model and performing simulated searches.""" |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
import functools |
|
|
|
# Hugging Face Hub id of the ZeroSearch search-simulation model (14B variant).
MODEL_NAME = "sunhaonlp/SearchSimulation_14B"
|
|
|
@functools.lru_cache(maxsize=1)
def _load_search_pipe():
    """Build and cache the text-generation pipeline for the simulation model.

    The ``lru_cache(maxsize=1)`` decorator makes this a lazy singleton: the
    (large) model is loaded from the Hub only on first use and reused after.

    Returns:
        A ``transformers`` text-generation pipeline configured for greedy
        (deterministic) decoding with up to 512 new tokens.
    """
    # trust_remote_code=True on both halves: the model repo ships custom code,
    # and the tokenizer must be loaded under the same policy as the model.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        device_map="auto",  # let accelerate place layers on available devices
    )
    # do_sample=False selects greedy decoding. Do NOT also pass temperature:
    # greedy search ignores it, and temperature=0.0 trips generation-config
    # validation warnings/errors in recent transformers releases.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=False,
    )
|
|
|
def simulate_search(query: str, k: int = 5) -> list[str]:
    """Generate *k* synthetic documents for *query*.

    Args:
        query: The search query to simulate results for.
        k: Number of synthetic documents to generate. Non-positive values
            yield an empty list.

    Returns:
        A list of ``k`` generated document strings (the text after the final
        ``"Documents:"`` marker in each model output).
    """
    if k <= 0:
        return []
    pipe = _load_search_pipe()
    prompt = f"SearchSimulation:\nQuery: {query}\nDocuments:"
    # The pipeline decodes greedily (do_sample=False); transformers raises
    # ValueError for num_return_sequences > 1 unless num_beams is at least as
    # large. Beam search with k beams yields k distinct candidates and, for
    # k == 1, degenerates to plain greedy decoding.
    outputs = pipe(prompt, num_return_sequences=k, num_beams=k)
    # generated_text echoes the prompt; keep only what follows the last
    # "Documents:" marker (the model's synthetic documents).
    return [o["generated_text"].split("Documents:")[-1].strip() for o in outputs]
|
|