""" | |
LLaVA model implementation. | |
""" | |
import torch | |
from transformers import AutoProcessor, AutoModelForCausalLM | |
from PIL import Image | |
from ..configs.settings import MODEL_NAME, MODEL_REVISION, DEVICE | |
from ..utils.logging import get_logger | |
logger = get_logger(__name__) | |
class LLaVAModel:
    """LLaVA model wrapper class."""

    def __init__(self):
        """Initialize the LLaVA model and processor."""
        try:
            logger.info(f"Initializing LLaVA model from {MODEL_NAME}")
            logger.info(f"Using device: {DEVICE}")

            # Initialize processor
            self.processor = AutoProcessor.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                trust_remote_code=True
            )

            # Set model dtype based on device: full precision on CPU, half precision on GPU
            model_dtype = torch.float32 if DEVICE == "cpu" else torch.float16

            # Initialize model with appropriate settings
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                torch_dtype=model_dtype,
                device_map="auto" if DEVICE == "cuda" else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            # Move model to device explicitly when device_map is not used
            if DEVICE == "cpu":
                self.model = self.model.to(DEVICE)

            logger.info("Model initialization complete")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise
    def generate_response(
        self,
        image: Image.Image,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        """
        Generate a response for the given image and prompt.

        Args:
            image: Input image as PIL Image
            prompt: Text prompt for the model
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter

        Returns:
            str: Generated response
        """
        try:
            # Prepare inputs; pass text and image as keyword arguments so they
            # are not silently swapped (positional order varies across processor versions)
            inputs = self.processor(
                text=prompt,
                images=image,
                return_tensors="pt"
            ).to(DEVICE)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True
                )

            # Decode only the newly generated tokens, skipping the prompt tokens
            # that generate() echoes back at the start of the output sequence
            prompt_length = inputs["input_ids"].shape[1]
            response = self.processor.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            )

            logger.debug(f"Generated response: {response[:100]}...")
            return response
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise
    def __call__(self, *args, **kwargs):
        """Convenience method to call generate_response."""
        return self.generate_response(*args, **kwargs)
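

# A minimal usage sketch, assuming the package layout implied by the relative
# imports above. The module path, the image path "example.jpg", and the
# LLaVA-1.5-style "USER: <image> ... ASSISTANT:" prompt template are
# assumptions; the actual prompt format depends on MODEL_NAME in
# ..configs.settings.
#
#     from PIL import Image
#     from .model import LLaVAModel  # hypothetical module path
#
#     llava = LLaVAModel()
#     image = Image.open("example.jpg").convert("RGB")
#     answer = llava(
#         image,
#         "USER: <image>\nWhat is shown in this picture? ASSISTANT:",
#         max_new_tokens=256,
#     )
#     print(answer)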