from __future__ import annotations import spaces import gradio as gr from threading import Thread from transformers import TextIteratorStreamer import hashlib import os from transformers import AutoModel, AutoProcessor import torch import sys import subprocess from PIL import Image from cobra import load import time subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'packaging']) subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'ninja']) subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'mamba-ssm']) subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'causal-conv1d']) vlm = load("cobra+3b") if torch.cuda.is_available(): DEVICE = "cuda" DTYPE = torch.bfloat16 else: DEVICE = "cpu" DTYPE = torch.float32 vlm.to(DEVICE, dtype=DTYPE) prompt_builder = vlm.get_prompt_builder() @spaces.GPU def bot_streaming(message, history, temperature, top_k, max_new_tokens): if len(history) == 0: prompt_builder.prompt, prompt_builder.turn_count = "", 0 print(message) if message["files"]: image = message["files"][-1]["path"] else: # if there's no image uploaded for this turn, look for images in the past turns # kept inside tuples, take the last one for hist in history: if type(hist[0])==tuple: image = hist[0][0] image = Image.open(image).convert("RGB") prompt_builder.add_turn(role="human", message=message['text']) prompt_text = prompt_builder.get_prompt() # Generate from the VLM with torch.no_grad(): generated_text = vlm.generate( image, prompt_text, cg=True, do_sample=True, temperature=temperature, top_k=top_k, max_new_tokens=max_new_tokens, ) prompt_builder.add_turn(role="gpt", message=generated_text) time.sleep(0.04) yield generated_text demo = gr.ChatInterface(fn=bot_streaming, additional_inputs=[gr.Slider(0, 1, value=0.2, label="Temperature"), gr.Slider(1, 3, value=1, step=1, label="Top k"), gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens")], title="Cobra", description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it.", stop_btn="Stop Generation", multimodal=True, examples=[{"text": "Describe this image", "files":["./cobra.png"]}]) demo.launch(debug=True)