import sys
import logging
import os

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
# First try to import from llava
try:
    from llava.model.builder import load_pretrained_model
    from llava.mm_utils import process_images, tokenizer_image_token
    logger.info("Successfully imported llava modules")
except ImportError as e:
    logger.error(f"Failed to import llava modules: {e}")
    sys.exit(1)
# Then import other dependencies
try:
    from transformers import AutoTokenizer, AutoConfig
    import torch
    import requests
    from PIL import Image
    from io import BytesIO
    logger.info("Successfully imported other required modules")
except ImportError as e:
    logger.error(f"Failed to import dependency: {e}")
    sys.exit(1)

class LLaVAHelper:
    def __init__(self, model_name="llava-hf/llava-1.5-7b-hf"):
        """
        Initialize the LLaVA model for image-text processing
        """
        logger.info(f"Initializing LLaVAHelper with model: {model_name}")

        # Create cache directory if it doesn't exist
        os.makedirs("./model_cache", exist_ok=True)
        logger.info("Created model cache directory")

        # Try loading just the config to ensure the model is valid
        try:
            AutoConfig.from_pretrained(model_name)
            logger.info(f"Successfully loaded config for {model_name}")
        except Exception as e:
            logger.warning(f"Error loading model config: {e}")
            # Try a different model version as fallback
            model_name = "llava-hf/llava-1.5-13b-hf"
            logger.info(f"Trying alternative model: {model_name}")

        try:
            # Use specific tokenizer class to avoid issues
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir="./model_cache",
                use_fast=False,  # Use the Python implementation instead of the Rust one
                trust_remote_code=True
            )
            logger.info("Tokenizer loaded successfully")

            # Inspect the load_pretrained_model function to understand its parameters
            import inspect
            logger.info(f"load_pretrained_model signature: {inspect.signature(load_pretrained_model)}")
            # Try loading with different parameter combinations.
            # NOTE: in the reference LLaVA repo, load_pretrained_model returns
            # (tokenizer, model, image_processor, context_len), so unpack four values.
            logger.info("Loading model...")
            try:
                # First attempt - standard parameter order
                _, self.model, self.image_processor, _ = load_pretrained_model(
                    model_path=model_name,
                    model_base=None,
                    cache_dir="./model_cache",
                )
            except Exception as e1:
                logger.warning(f"First attempt to load model failed: {e1}")
                try:
                    # Second attempt - try with model_name parameter
                    _, self.model, self.image_processor, _ = load_pretrained_model(
                        model_name=model_name,
                        model_path=model_name,
                        model_base=None,
                        cache_dir="./model_cache",
                    )
                except Exception as e2:
                    logger.warning(f"Second attempt to load model failed: {e2}")
                    # Third attempt - minimal positional parameters
                    # (model_path, model_base, model_name in the reference repo)
                    _, self.model, self.image_processor, _ = load_pretrained_model(
                        model_name,
                        None,
                        model_name,
                    )
            logger.info("Model loaded successfully")
            self.model.eval()

            # Move model to appropriate device
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            if self.device == "cpu":
                # If using CPU, make sure model is in the right place
                self.model = self.model.to(self.device)
            logger.info(f"Model successfully loaded on {self.device}")
        except Exception as e:
            logger.error(f"Detailed initialization error: {e}")
            logger.error("Stack trace:", exc_info=True)
            raise

    def generate_answer(self, image, question):
        """
        Generate a response to a question about an image

        Args:
            image: PIL Image or path to image
            question: String question about the image

        Returns:
            String response from the model
        """
        try:
            # Handle image input (either PIL Image or path/URL)
            if isinstance(image, str):
                if image.startswith(('http://', 'https://')):
                    response = requests.get(image)
                    image = Image.open(BytesIO(response.content))
                else:
                    image = Image.open(image)

            # Preprocess image
            image_tensor = process_images(
                [image],
                self.image_processor,
                self.model.config
            )[0].unsqueeze(0).to(self.device)

            # Format prompt with question
            prompt = f"###Human: <image>\n{question}\n###Assistant:"

            # Tokenize prompt (tokenizer_image_token returns a 1D tensor of token ids,
            # so add a batch dimension before generation)
            input_ids = tokenizer_image_token(
                prompt,
                self.tokenizer,
                return_tensors="pt"
            ).unsqueeze(0).to(self.device)

            # Generate response
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids=input_ids,
                    images=image_tensor,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                )

            # Decode and extract response
            output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            return output.split("###Assistant:")[-1].strip()
        except Exception as e:
            return f"Error generating answer: {str(e)}"

# Example usage
if __name__ == "__main__":
    try:
        # Initialize model
        llava = LLaVAHelper()

        # Example with a local file
        # response = llava.generate_answer("path/to/your/image.jpg", "What's in this image?")

        # Example with a URL
        # image_url = "https://example.com/image.jpg"
        # response = llava.generate_answer(image_url, "Describe this image in detail.")
        # print(response)

        print("LLaVA model initialized successfully. Ready to process images.")
    except Exception as e:
        print(f"Error initializing LLaVA: {e}")