from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from typing import List, Dict, Any
import gc
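
# NOTE: the quantized loading paths below assume that `bitsandbytes` and
# `accelerate` are installed alongside `transformers` and that a CUDA GPU is
# available; if quantized loading fails, __init__ falls back to the smaller
# microsoft/phi-2 model.
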
class ResponseGenerator:
    def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
        """
        Initialize the response generator with an LLM.
        Optimized for an 8-11 GB GPU.
        """
        print(f"Loading LLM: {model_name}")
        print("This may take a few minutes...")

        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Configure the device based on available resources
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # Free up memory before loading the model
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        # Configure 4-bit quantization for maximum memory efficiency
        try:
            # Use 4-bit quantization for models that support it
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            # Load the model with quantization
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16,
                # Cap GPU usage and spill the remainder to CPU to avoid OOM errors
                max_memory={0: "8GiB", "cpu": "16GiB"},
                offload_folder="offload",
                offload_state_dict=True,  # Offload the state dict to CPU while loading
                low_cpu_mem_usage=True
            )
        except Exception as e:
            print(f"4-bit quantization error: {e}")
            print("Falling back to 8-bit quantization...")
            try:
                # Try 8-bit quantization (double quantization only applies to 4-bit)
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16,
                    max_memory={0: "8GiB", "cpu": "16GiB"},
                    offload_folder="offload",
                    low_cpu_mem_usage=True
                )
            except Exception as e2:
                print(f"8-bit quantization error: {e2}")
                print("Falling back to a smaller model...")
                # Use a much smaller model as a fallback
                backup_model = "microsoft/phi-2"
                self.tokenizer = AutoTokenizer.from_pretrained(backup_model)
                self.model = AutoModelForCausalLM.from_pretrained(
                    backup_model,
                    device_map="auto",
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32
                )
        print("LLM loaded successfully")
    def generate_response(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
        """
        Generate a response using the LLM with the retrieved FAQs as context.
        Memory-optimized version.
        """
        # Create the prompt with the relevant FAQs
        prompt = self._create_prompt(query, relevant_faqs)

        # Generate a response with memory-efficient settings
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            # Use conservative generation parameters
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,  # Shorter responses for memory efficiency
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the generated part (everything after the prompt)
        response = response[len(prompt):].strip()

        # Clear GPU memory after generating the response
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return response
    def _create_prompt(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
        """
        Create a prompt for the LLM with the retrieved FAQs as context.
        """
        # Format the FAQs in a way that's suitable for the model
        faq_context = "\n\n".join([
            f"Q: {faq['question']}\nA: {faq['answer']}"
            for faq in relevant_faqs
        ])

        # Create the prompt
        prompt = f"""
Below are some relevant e-commerce customer support FAQ entries:
{faq_context}
Based on the information above, please provide a helpful, accurate, and concise response to the following customer query:
Customer Query: {query}
Response:
"""
        return prompt
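

# A minimal usage sketch, assuming the retriever supplies FAQ dicts with
# "question" and "answer" keys; the sample FAQs below are hypothetical
# placeholders, not data from the project.
if __name__ == "__main__":
    generator = ResponseGenerator()
    sample_faqs = [
        {"question": "How long does shipping take?",
         "answer": "Standard shipping takes 3-5 business days."},
        {"question": "Can I track my order?",
         "answer": "Yes, a tracking link is emailed once the order ships."},
    ]
    print(generator.generate_response("Where is my package?", sample_faqs))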