# faq-rag-chatbot/src/llm_response.py
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from typing import List, Dict, Any
import gc

class ResponseGenerator:
    def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
        """
        Initialize the response generator with an LLM.
        Optimized for GPUs with 8-11 GB of VRAM.
        """
        print(f"Loading LLM: {model_name}")
        print("This may take a few minutes...")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Configure device and data type based on available resources
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # Free up memory before loading the model
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        # Configure 4-bit quantization for maximum memory efficiency
        try:
            # Use 4-bit quantization for models that support it
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            # Load the model with quantization
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16,
                # Cap GPU memory and offload the remainder to CPU/disk to avoid OOM errors
                max_memory={0: "8GiB", "cpu": "16GiB"},
                offload_folder="offload",
                offload_state_dict=True,  # Offload the state dict during loading to reduce peak RAM usage
                low_cpu_mem_usage=True
            )
        except Exception as e:
            print(f"4-bit quantization error: {e}")
            print("Falling back to 8-bit quantization...")
            try:
                # Try 8-bit quantization (double quantization only applies to 4-bit, so it is not set here)
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16,
                    max_memory={0: "8GiB", "cpu": "16GiB"},
                    offload_folder="offload",
                    low_cpu_mem_usage=True
                )
            except Exception as e2:
                print(f"8-bit quantization error: {e2}")
                print("Falling back to smaller model...")
                # Use a much smaller model as fallback
                backup_model = "microsoft/phi-2"
                self.tokenizer = AutoTokenizer.from_pretrained(backup_model)
                self.model = AutoModelForCausalLM.from_pretrained(
                    backup_model,
                    device_map="auto",
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32
                )
        print("LLM loaded successfully")

    def generate_response(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
        """
        Generate a response using the LLM with retrieved FAQs as context.
        Memory-optimized version.
        """
        # Create prompt with relevant FAQs
        prompt = self._create_prompt(query, relevant_faqs)

        # Generate response with memory-efficient settings
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            # Use more conservative generation parameters
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,  # Shorter response for memory efficiency
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) is fragile because the tokenizer may not reproduce the prompt verbatim
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        # Clear GPU memory after generating the response
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return response

    def _create_prompt(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
        """
        Create a prompt for the LLM with retrieved FAQs as context.
        """
        # Format FAQs in a way that's suitable for the model
        faq_context = "\n\n".join([
            f"Q: {faq['question']}\nA: {faq['answer']}"
            for faq in relevant_faqs
        ])

        # Create the prompt
        prompt = f"""
Below are some relevant e-commerce customer support FAQ entries:
{faq_context}
Based on the information above, please provide a helpful, accurate, and concise response to the following customer query:
Customer Query: {query}
Response:
"""
        return prompt
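

# Example usage: an illustrative sketch, not part of the original module. The sample
# FAQ entries and the customer query below are hypothetical; in the full RAG pipeline
# the relevant_faqs list would come from the retrieval component rather than being
# hard-coded like this.
if __name__ == "__main__":
    generator = ResponseGenerator()

    # Hypothetical retrieved FAQs, shaped as the 'question'/'answer' dicts that
    # _create_prompt expects.
    sample_faqs = [
        {
            "question": "How long does standard shipping take?",
            "answer": "Standard shipping typically takes 5-7 business days.",
        },
        {
            "question": "Can I track my order?",
            "answer": "Yes, a tracking link is emailed once the order ships.",
        },
    ]

    answer = generator.generate_response("Where is my package?", sample_faqs)
    print(answer)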