File size: 3,093 Bytes
b493bea
 
 
 
 
 
 
 
 
de4c8be
341d1bf
 
b493bea
 
341d1bf
ca246a4
341d1bf
 
 
 
 
 
 
 
 
 
5f53f7b
 
 
 
 
 
341d1bf
 
 
 
 
 
 
 
 
 
 
 
de4c8be
341d1bf
 
 
b493bea
341d1bf
 
b493bea
341d1bf
 
de4c8be
341d1bf
 
de4c8be
341d1bf
 
de4c8be
341d1bf
 
 
 
 
 
b493bea
341d1bf
 
 
 
b493bea
341d1bf
 
 
b493bea
341d1bf
b493bea
341d1bf
b493bea
341d1bf
 
61ce98e
341d1bf
b493bea
341d1bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Dict, List, Any
from tempfile import TemporaryDirectory
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
import torch
import requests


class EndpointHandler:
    """Inference endpoint handler — currently a pass-through stub.

    The original LLaVA-Next (llava-v1.6-mistral-7b) load-and-generate
    pipeline has been disabled; this handler performs no model inference
    and simply echoes the request's ``inputs`` field back to the caller.
    The removed implementation is available in version control history.
    """

    def __init__(self) -> None:
        # Intentionally empty: model/processor loading is disabled while the
        # endpoint runs in stub mode, so startup does no work and needs no GPU.
        pass

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Echo the payload's ``inputs`` value, or report its absence.

        Args:
            data: Request payload; expected to carry an ``"inputs"`` key.

        Returns:
            The ``"inputs"`` value exactly as received when it is present
            and truthy (so the concrete type depends on the caller's
            payload — typically a dict or string), otherwise a
            single-element list ``[{"error": "No inputs provided"}]``.
        """
        inputs = data.get("inputs", "")
        # Treat a missing key, empty string, or any other falsy payload
        # uniformly as "no inputs".
        if not inputs:
            return [{"error": "No inputs provided"}]
        return inputs