""" |
|
LLM Interface Module for Cross-Domain Uncertainty Quantification |
|
|
|
This module provides a unified interface for interacting with large language models, |
|
supporting multiple model architectures and uncertainty quantification methods. |
|
""" |
|
|
|
import torch |
|
import numpy as np |
|
from typing import List, Dict, Any, Union, Optional |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM |
|
from tqdm import tqdm |
|
|
|
class LLMInterface: |
|
"""Interface for interacting with large language models with uncertainty quantification.""" |
|
|
|
    def __init__(
        self,
        model_name: str,
        model_type: str = "causal",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        cache_dir: Optional[str] = None,
        max_length: int = 512,
        temperature: float = 1.0,
        top_p: float = 1.0,
        num_beams: int = 1
    ):
        """
        Initialize the LLM interface.

        Args:
            model_name: Name of the Hugging Face model to use
            model_type: Type of model ('causal' or 'seq2seq')
            device: Device to run the model on ('cpu' or 'cuda')
            cache_dir: Directory to cache models
            max_length: Maximum length of generated sequences
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            num_beams: Number of beams for beam search
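
        Example (illustrative; "gpt2" stands in for any Hugging Face checkpoint):

            llm = LLMInterface("gpt2", model_type="causal", temperature=0.7)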
""" |
|
self.model_name = model_name |
|
self.model_type = model_type |
|
self.device = device |
|
self.cache_dir = cache_dir |
|
self.max_length = max_length |
|
self.temperature = temperature |
|
self.top_p = top_p |
|
self.num_beams = num_beams |
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained( |
|
model_name, |
|
cache_dir=cache_dir |
|
) |
|
|
|
|
|
if model_type == "causal": |
|
self.model = AutoModelForCausalLM.from_pretrained( |
|
model_name, |
|
cache_dir=cache_dir, |
|
torch_dtype=torch.float16 if device == "cuda" else torch.float32 |
|
).to(device) |
|
elif model_type == "seq2seq": |
|
self.model = AutoModelForSeq2SeqLM.from_pretrained( |
|
model_name, |
|
cache_dir=cache_dir, |
|
torch_dtype=torch.float16 if device == "cuda" else torch.float32 |
|
).to(device) |
|
else: |
|
raise ValueError(f"Unsupported model type: {model_type}") |
|
|
|
|
|
self.response_cache = {} |
|
|
|
    def generate(
        self,
        prompt: str,
        num_samples: int = 1,
        return_logits: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate responses from the model with uncertainty quantification.

        Args:
            prompt: Input text prompt
            num_samples: Number of samples to generate (for MC methods)
            return_logits: Whether to return token logits
            **kwargs: Additional generation parameters

        Returns:
            Dictionary containing:
                - response: The generated text
                - samples: Multiple samples if num_samples > 1
                - logits: Token logits if return_logits is True
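
        Example (illustrative; assumes an LLMInterface instance named llm
        created with temperature > 0, so repeated samples can differ):

            out = llm.generate("What causes tides?", num_samples=5, return_logits=True)
            out["response"]      # first generated answer
            len(out["samples"])  # 5 sampled answers for MC-style estimates
            out["logits"][0]     # per-step score arrays for the first sample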
""" |
|
|
|
cache_key = (prompt, num_samples, return_logits, str(kwargs)) |
|
if cache_key in self.response_cache: |
|
return self.response_cache[cache_key] |
|
|
|
|
|
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) |
|
|
|
|
|
gen_kwargs = { |
|
"max_length": self.max_length, |
|
"temperature": self.temperature, |
|
"top_p": self.top_p, |
|
"num_beams": self.num_beams, |
|
"do_sample": self.temperature > 0, |
|
"pad_token_id": self.tokenizer.eos_token_id |
|
} |
|
gen_kwargs.update(kwargs) |
|
|
|
|
|
samples = [] |
|
all_logits = [] |
|
|
|
for _ in range(num_samples): |
|
with torch.no_grad(): |
|
outputs = self.model.generate( |
|
**inputs, |
|
output_scores=return_logits, |
|
return_dict_in_generate=True, |
|
**gen_kwargs |
|
) |
|
|
|
|
|
if self.model_type == "causal": |
|
gen_tokens = outputs.sequences[0, inputs.input_ids.shape[1]:] |
|
else: |
|
gen_tokens = outputs.sequences[0] |
|
|
|
|
|
gen_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True) |
|
samples.append(gen_text) |
|
|
|
|
|
if return_logits and hasattr(outputs, "scores"): |
|
all_logits.append([score.cpu().numpy() for score in outputs.scores]) |
|
|
|
|
|
result = { |
|
"response": samples[0], |
|
"samples": samples |
|
} |
|
|
|
if return_logits: |
|
result["logits"] = all_logits |
|
|
|
|
|
self.response_cache[cache_key] = result |
|
return result |
|
|
|
    def batch_generate(
        self,
        prompts: List[str],
        **kwargs
    ) -> List[Dict[str, Any]]:
        """
        Generate responses for a batch of prompts.

        Args:
            prompts: List of input text prompts
            **kwargs: Additional generation parameters

        Returns:
            List of generation results for each prompt
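
        Example (illustrative; assumes an LLMInterface instance named llm):

            results = llm.batch_generate(["Prompt one?", "Prompt two?"], num_samples=3)
            results[0]["samples"]  # the three samples drawn for the first prompt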
""" |
|
results = [] |
|
for prompt in tqdm(prompts, desc="Generating responses"): |
|
results.append(self.generate(prompt, **kwargs)) |
|
return results |
|
|
|
def clear_cache(self): |
|
"""Clear the response cache.""" |
|
self.response_cache = {} |
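

# Minimal usage sketch (illustrative only; not part of the original interface).
# It assumes a small causal checkpoint such as "gpt2" is available from the
# Hugging Face Hub; substitute whatever causal or seq2seq model you actually use.
if __name__ == "__main__":
    llm = LLMInterface(model_name="gpt2", model_type="causal", temperature=0.7, max_length=64)

    # Draw several samples for one prompt; agreement between samples serves as a
    # crude, illustrative proxy for model confidence.
    out = llm.generate("The capital of France is", num_samples=5)
    agreement = sum(s == out["response"] for s in out["samples"]) / len(out["samples"])
    print(f"First sample: {out['response']!r}")
    print(f"Self-agreement across samples: {agreement:.2f}")

    # Batched prompts reuse the same generation settings; clear the cache when done.
    batch = llm.batch_generate(["2 + 2 =", "The sky is"], num_samples=2)
    print(f"Generated {len(batch)} batched results")
    llm.clear_cache()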