File size: 3,394 Bytes
e5d40e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70a8a19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5d40e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
LLaVA model implementation.
"""

import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image

from ..configs.settings import MODEL_NAME, MODEL_REVISION, DEVICE
from ..utils.logging import get_logger

logger = get_logger(__name__)

class LLaVAModel:
    """Thin wrapper around a Hugging Face LLaVA checkpoint.

    Loads the processor and causal-LM weights named by the module-level
    ``MODEL_NAME`` / ``MODEL_REVISION`` settings and exposes a single
    :meth:`generate_response` entry point for image+text inference.
    """

    def __init__(self):
        """Initialize the LLaVA processor and model.

        Raises:
            Exception: Re-raises any error from ``from_pretrained`` after
                logging it (e.g. network failure, missing checkpoint).
        """
        try:
            logger.info("Initializing LLaVA model from %s", MODEL_NAME)
            logger.info("Using device: %s", DEVICE)

            # Processor handles both tokenization and image preprocessing.
            self.processor = AutoProcessor.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                trust_remote_code=True
            )

            # fp16 halves memory on accelerators; CPU kernels need fp32.
            model_dtype = torch.float32 if DEVICE == "cpu" else torch.float16

            # device_map="auto" lets accelerate shard/place weights on CUDA;
            # on CPU we place the model explicitly below instead.
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                torch_dtype=model_dtype,
                device_map="auto" if DEVICE == "cuda" else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            # Only move manually when device_map was not used.
            if DEVICE == "cpu":
                self.model = self.model.to(DEVICE)

            logger.info("Model initialization complete")

        except Exception:
            # logger.exception records the full traceback, unlike error().
            logger.exception("Error initializing model")
            raise

    def generate_response(
        self,
        image: Image.Image,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        """
        Generate a response for the given image and prompt.

        Args:
            image: Input image as PIL Image
            prompt: Text prompt for the model
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter

        Returns:
            str: Generated response

        Raises:
            Exception: Re-raises any processor/generation error after
                logging it.
        """
        try:
            # Keyword arguments are required for version safety: the
            # positional order of (text, images) in LlavaProcessor.__call__
            # has changed between transformers releases.
            # NOTE(review): on CUDA the model runs in fp16 while the
            # processor emits fp32 pixel_values; if generation fails with a
            # dtype mismatch, cast via .to(DEVICE, torch.float16) — confirm
            # against the deployed transformers version.
            inputs = self.processor(
                text=prompt,
                images=image,
                return_tensors="pt"
            ).to(DEVICE)

            # Sampling-based decoding; temperature/top_p take effect
            # because do_sample=True.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True
                )

            # NOTE(review): outputs[0] includes the prompt tokens for
            # decoder-only models, so the decoded text likely echoes the
            # prompt; slice off inputs["input_ids"].shape[1] tokens if
            # callers want the completion only — confirm desired contract.
            response = self.processor.decode(
                outputs[0],
                skip_special_tokens=True
            )

            logger.debug("Generated response: %s...", response[:100])
            return response

        except Exception:
            logger.exception("Error generating response")
            raise

    def __call__(self, *args, **kwargs):
        """Convenience method delegating to :meth:`generate_response`."""
        return self.generate_response(*args, **kwargs)