"""SpaceLlama3.1 demo gradio app."""

import datetime
import logging
import os

import gradio as gr
import torch
import PIL.Image
from prismatic import load
from huggingface_hub import login

# Authenticate with the Hugging Face Hub
def authenticate_huggingface():
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
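        # login() registers the token with huggingface_hub so gated model downloads succeed.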
        login(token=hf_token)
    else:
        raise ValueError("Hugging Face API token not found. Please set it as an environment variable named 'HF_TOKEN'.")

# Call the authentication function once at the start
authenticate_huggingface()

INTRO_TEXT = """SpaceLlama3.1 demo\n\n
| [Model](https://huggingface.co/remyxai/SpaceLlama3.1) 
| [GitHub](https://github.com/remyxai/VQASynth/tree/main) 
| [Demo](https://huggingface.co/spaces/remyxai/SpaceLlama3.1) 
| [Discord](https://discord.gg/DAy3P5wYJk) 
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""

# Set model location as a constant outside the function
MODEL_LOCATION = "remyxai/SpaceLlama3.1"  # Update as needed

# Global model variable
global_model = None

def load_model():
    """Loads the model globally."""
    global global_model
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    global_model = load(MODEL_LOCATION)
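    # bfloat16 halves memory use relative to float32; note that CPU inference will be slow.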
    global_model.to(device, dtype=torch.bfloat16)
    logging.info("Model loaded successfully.")

def compute(image, prompt):
    """Runs model inference."""
    if image is None:
        raise gr.Error("Image required")
    if global_model is None:
        raise gr.Error("Model not loaded; call load_model() before launching the app.")

    logging.info('prompt="%s"', prompt)

    # Open the image file
    if isinstance(image, str):
        image = PIL.Image.open(image).convert("RGB")

    # Use the globally loaded model
    vlm = global_model

    # Prepare prompt
    prompt_builder = vlm.get_prompt_builder()
    prompt_builder.add_turn(role="human", message=prompt)
    prompt_text = prompt_builder.get_prompt()
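    # prompt_text now carries the user message rendered in the model's chat template.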

    # Generate the text based on image and prompt
    generated_text = vlm.generate(
        image,
        prompt_text,
        do_sample=True,
        temperature=0.1,
        max_new_tokens=512,
        min_length=1,
    )
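    # Keep only the text before the end-of-sequence marker left in the decoded output.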
    output = generated_text.split("</s>")[0]

    logging.info('output="%s"', output)

    return output  # Plain string for the Textbox component.

def reset():
    """Resets the input fields."""
    return "", None

def create_app():
    """Creates demo UI."""

    with gr.Blocks() as demo:
        # Main UI structure
        gr.Markdown(INTRO_TEXT)
        with gr.Row():
            image = gr.Image(value=None, label="Image", type="filepath", visible=True)  # input
            with gr.Column():
                prompt = gr.Textbox(value="", label="Prompt", visible=True)
                model_info = gr.Markdown(label="Model Info")
                run = gr.Button("Run", variant="primary")
                clear = gr.Button("Clear")
                output_text = gr.Textbox(value="", label="Output", visible=True)

        # Button event handlers
        run.click(
            fn=compute,
            inputs=[image, prompt],
            outputs=output_text,
        )
        clear.click(fn=reset, inputs=None, outputs=[prompt, image])

        # Status
        status = gr.Markdown(f"Startup: {datetime.datetime.now()}")
        gpu_kind = gr.Markdown(f"GPU=?")
        demo.load(
            fn=lambda: f"Model `{MODEL_LOCATION}` loaded.",  # Ensure the output is a string
            inputs=None,
            outputs=model_info,
        )

    return demo

if __name__ == "__main__":

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    # Log the environment for debugging, redacting anything that looks like a secret.
    for k, v in os.environ.items():
        if any(s in k.upper() for s in ("TOKEN", "SECRET", "KEY", "PASSWORD")):
            v = "<redacted>"
        logging.info('environ["%s"] = %r', k, v)

    # Load the model once globally
    load_model()

    create_app().queue().launch()
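
# Usage (the token value is a placeholder):
#   HF_TOKEN=<your-token> python app.py
# Gradio serves the app on http://127.0.0.1:7860 by default.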