# transport_query_assistant/llava_inference.py
import sys
import logging
import os
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
# First try to import from llava
try:
    from llava.model.builder import load_pretrained_model
    from llava.mm_utils import process_images, tokenizer_image_token
    logger.info("Successfully imported llava modules")
except ImportError as e:
    logger.error(f"Failed to import llava modules: {e}")
    sys.exit(1)

# Then import other dependencies
try:
    from transformers import AutoTokenizer, AutoConfig
    import torch
    import requests
    from PIL import Image
    from io import BytesIO
    logger.info("Successfully imported other required modules")
except ImportError as e:
    logger.error(f"Failed to import dependency: {e}")
    sys.exit(1)
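

# Illustrative sketch (not used by the class below): how the upstream LLaVA
# builder is typically called. Based on the haotian-liu/LLaVA repo,
# load_pretrained_model(model_path, model_base, model_name, ...) returns a
# 4-tuple (tokenizer, model, image_processor, context_len); verify this against
# the installed llava version. The helper name here is hypothetical.
def _load_llava_components(model_path, cache_dir="./model_cache"):
    """Load tokenizer, model and image processor via llava's builder (sketch)."""
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path=model_path,
        model_base=None,
        model_name=model_path,  # assumed: the builder infers the architecture from this name
        cache_dir=cache_dir,    # assumed: forwarded to from_pretrained via **kwargs
    )
    return tokenizer, model, image_processor, context_len
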
class LLaVAHelper:
    def __init__(self, model_name="llava-hf/llava-1.5-7b-hf"):
        """
        Initialize the LLaVA model for image-text processing.
        """
        logger.info(f"Initializing LLaVAHelper with model: {model_name}")

        # Create cache directory if it doesn't exist
        os.makedirs("./model_cache", exist_ok=True)
        logger.info("Created model cache directory")

        # Try loading just the config to ensure the model identifier is valid
        try:
            AutoConfig.from_pretrained(model_name)
            logger.info(f"Successfully loaded config for {model_name}")
        except Exception as e:
            logger.warning(f"Error loading model config: {e}")
            # Fall back to a different model version
            model_name = "llava-hf/llava-1.5-13b-hf"
            logger.info(f"Trying alternative model: {model_name}")

        try:
            # Use the slow (Python) tokenizer to avoid issues with the Rust implementation
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir="./model_cache",
                use_fast=False,
                trust_remote_code=True,
            )
            logger.info("Tokenizer loaded successfully")

            # Inspect load_pretrained_model to understand its parameters
            import inspect
            logger.info(f"load_pretrained_model signature: {inspect.signature(load_pretrained_model)}")

            # Try loading with different parameter combinations.
            # Note: in the upstream LLaVA repo, load_pretrained_model returns
            # (tokenizer, model, image_processor, context_len), so unpack four values.
            logger.info("Loading model...")
            try:
                # First attempt - standard keyword parameters
                _, self.model, self.image_processor, _ = load_pretrained_model(
                    model_path=model_name,
                    model_base=None,
                    cache_dir="./model_cache",
                )
            except Exception as e1:
                logger.warning(f"First attempt to load model failed: {e1}")
                try:
                    # Second attempt - also pass the model_name parameter
                    _, self.model, self.image_processor, _ = load_pretrained_model(
                        model_name=model_name,
                        model_path=model_name,
                        model_base=None,
                        cache_dir="./model_cache",
                    )
                except Exception as e2:
                    logger.warning(f"Second attempt to load model failed: {e2}")
                    # Third attempt - positional parameters (model_path, model_base, model_name)
                    _, self.model, self.image_processor, _ = load_pretrained_model(
                        model_name,
                        None,
                        model_name,
                    )
            logger.info("Model loaded successfully")

            self.model.eval()

            # Move model to the appropriate device
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            if self.device == "cpu":
                # If using CPU, make sure the model is in the right place
                self.model = self.model.to(self.device)
            logger.info(f"Model successfully loaded on {self.device}")
        except Exception as e:
            logger.error(f"Detailed initialization error: {e}")
            logger.error("Stack trace:", exc_info=True)
            raise
    def generate_answer(self, image, question):
        """
        Generate a response to a question about an image.

        Args:
            image: PIL Image, path to an image file, or image URL
            question: String question about the image

        Returns:
            String response from the model
        """
        try:
            # Handle image input (either PIL Image or path/URL)
            if isinstance(image, str):
                if image.startswith(('http://', 'https://')):
                    response = requests.get(image)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                else:
                    image = Image.open(image)

            # Preprocess the image into a batched tensor on the target device
            image_tensor = process_images(
                [image],
                self.image_processor,
                self.model.config
            )[0].unsqueeze(0).to(self.device)

            # Format the prompt with the question
            prompt = f"###Human: <image>\n{question}\n###Assistant:"

            # Tokenize the prompt; tokenizer_image_token returns a 1-D tensor of
            # token ids, so add a batch dimension before generation
            input_ids = tokenizer_image_token(
                prompt,
                self.tokenizer,
                return_tensors="pt"
            ).unsqueeze(0).to(self.device)

            # Generate a response
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids=input_ids,
                    images=image_tensor,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                )

            # Decode and extract the assistant's portion of the output
            output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            return output.split("###Assistant:")[-1].strip()
        except Exception as e:
            return f"Error generating answer: {str(e)}"
# Example usage
if __name__ == "__main__":
    try:
        # Initialize model
        llava = LLaVAHelper()

        # Example with a local file
        # response = llava.generate_answer("path/to/your/image.jpg", "What's in this image?")

        # Example with a URL
        # image_url = "https://example.com/image.jpg"
        # response = llava.generate_answer(image_url, "Describe this image in detail.")
        # print(response)

        print("LLaVA model initialized successfully. Ready to process images.")
    except Exception as e:
        print(f"Error initializing LLaVA: {e}")