"""SpaceLlama3.1 demo gradio app."""
import datetime
import logging
import os
import gradio as gr
import torch
import PIL.Image
from prismatic import load
from huggingface_hub import login
# Authenticate with the Hugging Face Hub
def authenticate_huggingface():
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    else:
        raise ValueError(
            "Hugging Face API token not found. "
            "Please set it as an environment variable named 'HF_TOKEN'."
        )


# Call the authentication function once at the start
authenticate_huggingface()
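# Example (sketch; the token value below is illustrative only): the token is read
# from the environment, so it can be provided before launching the app, e.g.
#
#   export HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   python app.py
#
# or, on a Hugging Face Space, as a repository secret named HF_TOKEN.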
INTRO_TEXT = """SpaceLlama3.1 demo\n\n
| [Model](https://huggingface.co/remyxai/SpaceLlama3.1)
| [GitHub](https://github.com/remyxai/VQASynth/tree/main)
| [Demo](https://huggingface.co/spaces/remyxai/SpaceLlama3.1)
| [Discord](https://discord.gg/DAy3P5wYJk)
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""
# Set model location as a constant outside the function
MODEL_LOCATION = "remyxai/SpaceLlama3.1" # Update as needed
# Global model variable
global_model = None
def load_model():
    """Loads the model globally."""
    global global_model
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    global_model = load(MODEL_LOCATION)
    global_model.to(device, dtype=torch.bfloat16)
    logging.info("Model loaded successfully.")
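# Note (assumption about the runtime environment): the bfloat16 cast targets GPU
# inference. When no CUDA device is available, some CPU ops can be slow or
# unsupported in bfloat16; a float32 fallback is a safer sketch in that case:
#
#   global_model.to(device, dtype=torch.float32)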
def compute(image, prompt):
    """Runs model inference."""
    if image is None:
        raise gr.Error("Image required")

    logging.info('prompt="%s"', prompt)

    # Open the image file
    if isinstance(image, str):
        image = PIL.Image.open(image).convert("RGB")

    # Use the globally loaded model
    vlm = global_model

    # Prepare prompt
    prompt_builder = vlm.get_prompt_builder()
    prompt_builder.add_turn(role="human", message=prompt)
    prompt_text = prompt_builder.get_prompt()

    # Generate the text based on image and prompt
    generated_text = vlm.generate(
        image,
        prompt_text,
        do_sample=True,
        temperature=0.1,
        max_new_tokens=512,
        min_length=1,
    )
    output = generated_text.split("</s>")[0]

    logging.info('output="%s"', output)
    return output  # Ensure that output is a string
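# Example usage (sketch, not executed at import time; the image path and prompt
# are illustrative): once load_model() has run, compute() can also be called
# directly, outside the Gradio UI:
#
#   load_model()
#   answer = compute("example.jpg", "How far is the chair from the camera?")
#   print(answer)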
def reset():
    """Resets the input fields."""
    return "", None
def create_app():
    """Creates demo UI."""
    with gr.Blocks() as demo:
        # Main UI structure
        gr.Markdown(INTRO_TEXT)
        with gr.Row():
            image = gr.Image(value=None, label="Image", type="filepath", visible=True)  # input
            with gr.Column():
                prompt = gr.Textbox(value="", label="Prompt", visible=True)
                model_info = gr.Markdown(label="Model Info")
                run = gr.Button("Run", variant="primary")
                clear = gr.Button("Clear")
                highlighted_text = gr.HighlightedText(value="", label="Output", visible=True)

        # Button event handlers
        run.click(
            fn=compute,
            inputs=[image, prompt],
            outputs=highlighted_text,  # Ensure this is the right output component
        )
        clear.click(fn=reset, inputs=None, outputs=[prompt, image])

        # Status
        status = gr.Markdown(f"Startup: {datetime.datetime.now()}")
        gpu_kind = gr.Markdown("GPU=?")

        demo.load(
            fn=lambda: f"Model `{MODEL_LOCATION}` loaded.",  # Ensure the output is a string
            inputs=None,
            outputs=model_info,
        )

    return demo
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    # Note: this logs every environment variable, including secrets such as HF_TOKEN.
    for k, v in os.environ.items():
        logging.info('environ["%s"] = %r', k, v)

    # Load the model once globally
    load_model()

    create_app().queue().launch()
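# Launch options (sketch; optional, standard Gradio launch arguments): to bind to
# all interfaces or pin a port, launch() can instead be called as
#
#   create_app().queue().launch(server_name="0.0.0.0", server_port=7860)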