import torch
import transformers
import warnings
import time
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from PIL import Image
from threading import Thread

transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings("ignore")
device = "cuda"  # switch to "cpu" if no GPU is available
torch.set_default_device(device)

model_name = "BAAI/Bunny-v1_1-Llama-3-8B-V"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # torch.float32 for CPU
    device_map="auto",
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True)
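
# trust_remote_code=True loads Bunny's custom modeling code from the Hub; it
# supplies the multimodal helpers used below (model.process_images and the
# `images` keyword accepted by model.generate).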

# The decorator below is an assumption: `import spaces` suggests this demo
# targets a ZeroGPU Space, where the inference function must be wrapped with
# @spaces.GPU to be allocated a GPU.
@spaces.GPU
def bot_streaming(message, history):
    print(message)
    if message["files"]:
        # message["files"][-1] is a dict or just a path string
        if isinstance(message["files"][-1], dict):
            image_file = message["files"][-1]["path"]
        else:
            image_file = message["files"][-1]
    else:
        image_file = None
        # No image uploaded this turn: look for images in past turns
        # (kept inside tuples) and take the last one.
        for hist in history:
            if isinstance(hist[0], tuple):
                image_file = hist[0][0]
    prompt = message["text"]
    if image_file is None:
        text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {prompt} ASSISTANT:"
        input_ids = torch.tensor(tokenizer(text).input_ids, dtype=torch.long).unsqueeze(0).to(device)
    else:
        text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
        text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")]
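        # -200 is the image-token sentinel (IMAGE_TOKEN_INDEX in the LLaVA
        # convention Bunny follows); the model's remote code replaces it with
        # the projected vision features. text_chunks[1][1:] drops the BOS token
        # the tokenizer prepends to the second chunk, so the spliced sequence
        # keeps a single BOS.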
        input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long).unsqueeze(0).to(device)

    if image_file is not None:
        image = Image.open(image_file)
        image_tensor = model.process_images([image], model.config).to(dtype=model.dtype, device=device)
    else:
        image_tensor = None
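
    # Stream the reply: model.generate runs in a worker thread and pushes
    # decoded text into TextIteratorStreamer, which this generator drains below.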
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        top_p=0.7,
        max_new_tokens=512,
        streamer=streamer,
        use_cache=True,
        repetition_penalty=1.08
    ))
    thread.start()
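
    # Yield the cumulative text on each step: gr.ChatInterface replaces the bot
    # message with every yielded value, so each yield is the full response so far.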
    buffer = ""
    time.sleep(0.5)  # give the worker thread a moment to start producing tokens
    for new_text in streamer:
        # Strip the end-of-text marker if it leaks into the decoded stream.
        if "<|end_of_text|>" in new_text:
            new_text = new_text.split("<|end_of_text|>")[0]
        buffer += new_text
        time.sleep(0.06)  # small delay to smooth the UI update rate
        yield buffer

title_markdown = ("""
# 🐰 Bunny: A family of lightweight multimodal models

[📖 [Technical report](https://arxiv.org/abs/2402.11530)] | [🏠 [Code](https://github.com/BAAI-DCAI/Bunny)] | [🤗 [Bunny-v1.1-Llama-3-8B-V](https://huggingface.co/BAAI/Bunny-v1_1-Llama-3-8B-V)] | [🤗 [Bunny-v1.1-4B](https://huggingface.co/BAAI/Bunny-v1_1-4B)] | [🤗 [Bunny-v1.0-3B](https://huggingface.co/BAAI/Bunny-v1_0-3B)]
""")

chatbot = gr.Chatbot(
    elem_id="chatbot",
    label="Bunny-v1.1-Llama-3-8B-V",
    avatar_images=["./assets/user.png", "./assets/icon.jpg"],
    height=550
)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False
)

with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(title_markdown)
    gr.ChatInterface(
        fn=bot_streaming,
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot
    )
    gr.Examples(
        examples=[{"text": "What is the astronaut holding in his hand?", "files": ["./assets/example_1.png"]},
                  {"text": "Why is the image funny?", "files": ["./assets/example_2.png"]}],
        inputs=chat_input
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)