# agent_tuning_framework/llm_interface.py
"""
LLM Interface Module for Cross-Domain Uncertainty Quantification
This module provides a unified interface for interacting with large language models,
supporting multiple model architectures and uncertainty quantification methods.
"""
import torch
import numpy as np
from typing import List, Dict, Any, Union, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm
class LLMInterface:
"""Interface for interacting with large language models with uncertainty quantification."""
def __init__(
self,
model_name: str,
model_type: str = "causal",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
cache_dir: Optional[str] = None,
max_length: int = 512,
temperature: float = 1.0,
top_p: float = 1.0,
num_beams: int = 1
):
"""
Initialize the LLM interface.
Args:
model_name: Name of the Hugging Face model to use
model_type: Type of model ('causal' or 'seq2seq')
device: Device to run the model on ('cpu' or 'cuda')
cache_dir: Directory to cache models
max_length: Maximum length of generated sequences
temperature: Sampling temperature
top_p: Nucleus sampling parameter
num_beams: Number of beams for beam search
"""
self.model_name = model_name
self.model_type = model_type
self.device = device
self.cache_dir = cache_dir
self.max_length = max_length
self.temperature = temperature
self.top_p = top_p
self.num_beams = num_beams
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=cache_dir
)
        # Load model based on type (half precision on GPU to reduce memory)
        if model_type == "causal":
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                torch_dtype=torch.float16 if device.startswith("cuda") else torch.float32
            ).to(device)
        elif model_type == "seq2seq":
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                torch_dtype=torch.float16 if device.startswith("cuda") else torch.float32
            ).to(device)
else:
raise ValueError(f"Unsupported model type: {model_type}")
# Response cache for efficiency
self.response_cache = {}
def generate(
self,
prompt: str,
num_samples: int = 1,
return_logits: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
Generate responses from the model with uncertainty quantification.
Args:
prompt: Input text prompt
num_samples: Number of samples to generate (for MC methods)
return_logits: Whether to return token logits
**kwargs: Additional generation parameters
Returns:
Dictionary containing:
- response: The generated text
- samples: Multiple samples if num_samples > 1
- logits: Token logits if return_logits is True
"""
        # Check cache first (sort kwargs so argument order does not change the key)
        cache_key = (prompt, num_samples, return_logits, str(sorted(kwargs.items())))
if cache_key in self.response_cache:
return self.response_cache[cache_key]
# Prepare inputs
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        # Set generation parameters; note that max_length is the total sequence
        # length (prompt + continuation) for causal models
        gen_kwargs = {
            "max_length": self.max_length,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "num_beams": self.num_beams,
            "do_sample": self.temperature > 0,  # fall back to greedy/beam search at temperature 0
            "pad_token_id": self.tokenizer.eos_token_id
        }
gen_kwargs.update(kwargs)
# Generate multiple samples if requested
samples = []
all_logits = []
for _ in range(num_samples):
with torch.no_grad():
outputs = self.model.generate(
**inputs,
output_scores=return_logits,
return_dict_in_generate=True,
**gen_kwargs
)
            # Extract generated tokens (causal outputs include the prompt tokens,
            # so strip the input portion first)
            if self.model_type == "causal":
                gen_tokens = outputs.sequences[0, inputs.input_ids.shape[1]:]
            else:
                gen_tokens = outputs.sequences[0]
# Decode tokens to text
gen_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True)
samples.append(gen_text)
# Extract logits if requested
if return_logits and hasattr(outputs, "scores"):
all_logits.append([score.cpu().numpy() for score in outputs.scores])
# Prepare result
result = {
"response": samples[0], # Primary response is first sample
"samples": samples
}
if return_logits:
result["logits"] = all_logits
# Cache result
self.response_cache[cache_key] = result
return result
def batch_generate(
self,
prompts: List[str],
**kwargs
) -> List[Dict[str, Any]]:
"""
Generate responses for a batch of prompts.
Args:
prompts: List of input text prompts
**kwargs: Additional generation parameters
Returns:
List of generation results for each prompt
"""
results = []
for prompt in tqdm(prompts, desc="Generating responses"):
results.append(self.generate(prompt, **kwargs))
return results
def clear_cache(self):
"""Clear the response cache."""
self.response_cache = {}
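

# Example usage (a minimal sketch). The model name "gpt2" is a placeholder
# assumption, and the sample-agreement score below is just one illustrative way
# downstream code might turn multiple samples into an uncertainty estimate; it
# is not part of this module's API.
if __name__ == "__main__":
    interface = LLMInterface(model_name="gpt2", model_type="causal", max_length=64)

    # Draw several stochastic samples for the same prompt (Monte Carlo style).
    result = interface.generate(
        "The capital of France is",
        num_samples=5,
        temperature=0.8,
    )
    print("Primary response:", result["response"])

    # Crude agreement score: the fraction of samples identical to the first one.
    # Higher agreement loosely suggests lower predictive uncertainty.
    samples = result["samples"]
    agreement = sum(s == samples[0] for s in samples) / len(samples)
    print(f"Sample agreement: {agreement:.2f}")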