"""
LLaVA model implementation.
"""
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
from ..configs.settings import MODEL_NAME, MODEL_REVISION, DEVICE
from ..utils.logging import get_logger

logger = get_logger(__name__)

class LLaVAModel:
"""LLaVA model wrapper class."""
def __init__(self):
"""Initialize the LLaVA model and processor."""
try:
logger.info(f"Initializing LLaVA model from {MODEL_NAME}")
logger.info(f"Using device: {DEVICE}")
# Initialize processor
self.processor = AutoProcessor.from_pretrained(
MODEL_NAME,
revision=MODEL_REVISION,
trust_remote_code=True
)
# Set model dtype based on device
model_dtype = torch.float32 if DEVICE == "cpu" else torch.float16
# Initialize model with appropriate settings
self.model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
revision=MODEL_REVISION,
torch_dtype=model_dtype,
device_map="auto" if DEVICE == "cuda" else None,
trust_remote_code=True,
low_cpu_mem_usage=True
)
# Move model to device if not using device_map
if DEVICE == "cpu":
self.model = self.model.to(DEVICE)
logger.info("Model initialization complete")
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
raise
    def generate_response(
        self,
        image: Image.Image,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        """
        Generate a response for the given image and prompt.

        Args:
            image: Input image as PIL Image
            prompt: Text prompt for the model
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter

        Returns:
            str: Generated response
        """
        try:
            # Prepare inputs; keyword arguments are used because the
            # positional order of text and images varies across processor
            # versions
            inputs = self.processor(
                text=prompt,
                images=image,
                return_tensors="pt"
            ).to(DEVICE)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True
                )

            # Decode only the newly generated tokens so the prompt is not
            # echoed back in the response
            generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = self.processor.decode(
                generated_tokens,
                skip_special_tokens=True
            )

            logger.debug(f"Generated response: {response[:100]}...")
            return response
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise
    def __call__(self, *args, **kwargs):
        """Convenience method to call generate_response."""
        return self.generate_response(*args, **kwargs)
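

# Usage sketch (hypothetical, not part of the original module): the image path
# "example.jpg" and the prompt are placeholders, and the prompt may need to
# follow the chat template expected by the checkpoint in MODEL_NAME. Because
# this file uses relative imports, run it as a module, e.g.
# `python -m <package>.llava_model`.
if __name__ == "__main__":
    model = LLaVAModel()
    demo_image = Image.open("example.jpg").convert("RGB")
    print(model(demo_image, "Describe this image.", max_new_tokens=128))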