# chameleon / app.py
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, ChameleonProcessor, TextIteratorStreamer
from threading import Thread
from PIL import Image
import requests
model_path = "facebook/chameleon-7b"
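# Note: facebook/chameleon-7b is distributed under Meta's research license and
# may be gated on the Hugging Face Hub; you may need to accept the license and
# authenticate (e.g. `huggingface-cli login`) before the weights will download.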
# model = ChameleonForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
# processor = ChameleonProcessor.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
model.eval()
processor = ChameleonProcessor.from_pretrained(model_path)
tokenizer = processor.tokenizer
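# ChameleonProcessor bundles the image processor and the tokenizer; the bare
# tokenizer is pulled out here because TextIteratorStreamer below expects a
# tokenizer rather than a full processor.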
# (file_name, alt_text)
multimodal_file = tuple[str, str]
# One side of a chat turn: plain text, a single file, or a mixed list of both
multimodal_message = list[str | multimodal_file] | multimodal_file
# Incoming message from gr.ChatInterface(multimodal=True), e.g. {'text': 'message here', 'files': []}
# todo: verify this type with gr.ChatInterface
message_t = dict[str, str | list[multimodal_file]]
# History is a list of (user message, assistant message) turns
history_t = list[tuple[multimodal_message, multimodal_message]]
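# Hypothetical example of the shapes above:
# message = {"text": "What's in this photo?", "files": ["/tmp/gradio/photo.png"]}
# history = [
#     ("Hello!", "Hi, how can I help?"),                                 # text-only turn
#     ([("/tmp/gradio/cat.png", "a cat"), "Name this animal"], "It's a cat."),
# ]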

def history_to_prompt(
    message: message_t,
    history: history_t,
    eot_id="<reserved08706>",  # Chameleon end-of-turn token
    image_placeholder="<image>",
):
    prompt = ""
    images = []
    # Replay earlier turns first so the current message comes last in the prompt.
    for turn in history:
        print("turn:", turn)
        # Each turn is a (user message, assistant message) pair.
        for part in turn:
            if isinstance(part, str):
                prompt += part
            elif isinstance(part, list):
                for item in part:
                    if isinstance(item, str):
                        prompt += item
                    elif isinstance(item, tuple):
                        image_path, alt = item
                        prompt += image_placeholder
                        # History items can reference local file paths or URLs.
                        if image_path.startswith(("http://", "https://")):
                            image = Image.open(requests.get(image_path, stream=True).raw)
                        else:
                            image = Image.open(image_path)
                        images.append(image)
            else:
                prompt += f"(unhandled message type: {part})"
            prompt += eot_id
    # Append the current user message; each attached file gets an <image>
    # placeholder so the token count matches the number of images.
    prompt += message["text"]
    for f in message["files"]:
        prompt += image_placeholder
        images.append(Image.open(f))
    return prompt, images
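
# Sketch of the output on hypothetical values: a text-only turn, then a new
# image question.
# prompt, images = history_to_prompt(
#     {"text": "Describe this picture:", "files": ["/tmp/gradio/cat.png"]},
#     [("Hello!", "Hi, how can I help?")],
# )
# prompt == "Hello!<reserved08706>Hi, how can I help?<reserved08706>Describe this picture:<image>"
# len(images) == 1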

@spaces.GPU(duration=30)
def respond(
    message: message_t,
    history: history_t,
    system_message,  # accepted from the UI but currently unused by the prompt format
    max_tokens,
    temperature,
    top_p,
):
    print(f"message: {message}\nhistory:\n\n{history}\n")
    prompt, images = history_to_prompt(message, history)
    print(f"prompt:\n\n{prompt}\n")
    # Hardcoded example kept for reference:
    # prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
    # image = Image.open(requests.get("https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True).raw)
    # images = [image]
    # Pass None rather than an empty list when the conversation has no images.
    inputs = processor(prompt, images=images or None, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,  # sampling must be on for temperature/top_p to take effect
        temperature=temperature,
        top_p=top_p,
    )
    try:
        # Launch generation in the background; the streamer yields decoded
        # text chunks as tokens are produced.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        partial_message = ""
        for new_token in streamer:
            partial_message += new_token
            yield partial_message
    except Exception as e:
        yield f"Error: {e}"
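
# Because respond is a generator, gr.ChatInterface renders each yielded
# partial_message as a progressively growing reply. Hypothetical usage:
# for partial in respond({"text": "Hi", "files": []}, [], "", 512, 0.7, 0.95):
#     print(partial)  # "Hel", "Hello", "Hello the", ... (cumulative text)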
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
respond,
multimodal=True,
additional_inputs=[
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
)
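
# The additional_inputs above are passed positionally to respond after
# (message, history): system_message, max_tokens, temperature, top_p.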
if __name__ == "__main__":
    demo.launch(debug=True)