"""Gradio demo for the TroL vision-language models (1.8B / 3.8B / 7B) on a ZeroGPU Space.

At import time this script installs flash-attn, loads all three TroL checkpoints
with ``load_trol`` and launches a multimodal ``gr.ChatInterface`` whose handler,
``bot_streaming``, answers a text question, optionally about a single attached image.
"""

# A100 Zero GPU (kept as the very first import, as in the original layout)
import spaces

# TroL Package
import torch
from PIL import Image
from utils.utils import *  # expected to provide output_filtering, used in bot_streaming
import torch.nn.functional as F
from trol.load_trol import load_trol
from torchvision.transforms.functional import pil_to_tensor

# Gradio Package
import os
import subprocess
import sys
import time
import traceback
import gradio as gr
from threading import Thread
from accelerate import Accelerator
from huggingface_hub import hf_hub_download
from transformers import TextIteratorStreamer

# flash attention
# Install flash-attn without compiling its CUDA kernels. The child inherits the
# current environment (PATH, HOME, virtualenv, HF caches); passing only the one
# flag as ``env`` would wipe all of those. ``sys.executable -m pip`` installs into
# the interpreter that is running this script.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

# accel
accel = Accelerator()

# User prompt (module-level defaults; not read by the Gradio handler below)
prompt_type = "with_image"  # Select one option "text_only", "with_image"
img_path = 'figures/demo.png'
question = "What is the troll doing? Provide the detail in the image and imagine what the event happens."

# loading models: all three sizes are loaded once, at import time
model_1_8, tokenizer_1_8 = load_trol(link='TroL-1.8B')
model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
model_7, tokenizer_7 = load_trol(link='TroL-7B')


def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
    """Prepare ``inputs`` with ``model.eval_process`` and run sampled generation.

    Meant to run on a worker thread. Decoded text is pushed into ``streamer``
    (a ``TextIteratorStreamer``) as it is produced, so another thread can consume it.

    Args:
        inputs: list with one dict holding 'question' and optionally an 'image' tensor.
        image_token_number: forwarded to ``eval_process`` as ``img_token_number``.
        streamer: streamer passed to ``model.generate``.
        device: device forwarded to ``eval_process``.
        model, tokenizer: the TroL model and its tokenizer.
        temperature, new_max_token, top_p: sampling settings for ``generate``.

    Returns:
        Whatever ``model.generate`` returns.
    """
    # propagation
    generation_kwargs = model.eval_process(inputs=inputs, data='demo', tokenizer=tokenizer,
                                           device=device, img_token_number=image_token_number)
    generation_kwargs.update(
        streamer=streamer,
        do_sample=True,
        max_new_tokens=new_max_token,
        top_p=top_p,
        temperature=temperature,
        use_cache=True,
    )
    return model.generate(**generation_kwargs)


def _select_model(link):
    """Return ``(model, tokenizer, hub_repo)`` for the size label chosen in the UI.

    Labels are matched by substring, in the same order as the original if/elif chain.

    Raises:
        ValueError: if ``link`` names none of the loaded sizes.
    """
    if "1.8B" in link:
        return model_1_8, tokenizer_1_8, "BK-Lee/TroL-1.8B"
    if "3.8B" in link:
        return model_3_8, tokenizer_3_8, "BK-Lee/TroL-3.8B"
    if "7B" in link:
        return model_7, tokenizer_7, "BK-Lee/TroL-7B"
    raise ValueError(f"Unknown TroL model size: {link!r}")


def _load_trol_gating(model, path):
    """(Re)initialise the TroL gating module and load ``trol_gating.pt`` from Hub repo ``path``.

    The gating is hosted on ``model.model`` when that object exposes
    ``initialize_trol_gating``. Otherwise it is on ``model.language_model.model``.
    """
    backbone = getattr(model, "model", None)
    if not hasattr(backbone, "initialize_trol_gating"):
        backbone = model.language_model.model
    backbone.initialize_trol_gating()
    gating_state = torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt"))
    backbone.trol_gating.load_state_dict(gating_state)


def _prepare_inputs(message, link):
    """Build the ``inputs`` list and image token count from a multimodal chat message.

    Returns:
        ``(inputs, image_token_number)``. ``image_token_number`` is None for
        text-only messages and for the 3.8B model.

    Raises:
        ValueError: if more than one file is attached.
    """
    files = message['files']
    if len(files) > 1:
        raise ValueError("No way!")
    if not files:
        return [{'question': message['text']}], None

    # Image Load (context manager closes the underlying file handle)
    with Image.open(files[0]) as img:
        image = pil_to_tensor(img.convert("RGB"))

    image_token_number = None
    if "3.8B" not in link:
        # NOTE(review): 1225 == 35 * 35, presumably one token per 14px patch of
        # the 490x490 resize below. Confirm against the TroL vision encoder.
        image_token_number = 1225
        image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
    return [{'image': image.to(accel.device), 'question': message['text']}], image_token_number


def _stream_generate(inputs, image_token_number, model, tokenizer, temperature, new_max_token, top_p):
    """Run ``threading_function`` on a worker thread and return the full decoded text.

    If the worker raises, the streamer is closed so the consumer loop stops, and the
    worker's exception is re-raised here. Without that the caller would block forever.
    """
    worker_errors = []

    # Kept from the original. PyTorch grad-mode state is thread-local, so this
    # context does not itself apply inside the worker thread.
    with torch.inference_mode():
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

        def _worker():
            try:
                threading_function(inputs=inputs, image_token_number=image_token_number,
                                   streamer=streamer, device=accel.device, model=model,
                                   tokenizer=tokenizer, temperature=temperature,
                                   new_max_token=new_max_token, top_p=top_p)
            except Exception as exc:
                worker_errors.append(exc)
                streamer.end()  # unblock the consumer loop below

        thread = Thread(target=_worker)
        thread.start()
        chunks = list(streamer)
        thread.join()

    if worker_errors:
        raise worker_errors[0]
    return "".join(chunks)


@spaces.GPU
def bot_streaming(message, history, link, temperature, new_max_token, top_p):
    """Gradio ChatInterface handler: answer ``message`` with the TroL model selected by ``link``.

    ``history`` is accepted for the ChatInterface signature but is not used, so
    there is no multi-turn context. The answer is yielded progressively, one
    more character per step.
    """
    # model selection
    model, tokenizer, path = _select_model(link)

    # trol gating load
    _load_trol_gating(model, path)

    # X -> float16 conversion, then cpu -> gpu. The dtype tuple is the same set
    # the original substring test matched: 'float16' also matches 'bfloat16'.
    for param in model.parameters():
        if param.dtype in (torch.float32, torch.float16, torch.bfloat16):
            param.data = param.data.to(torch.float16)
        if not param.is_cuda:
            param.data = param.data.to(accel.device)

    try:
        inputs, image_token_number = _prepare_inputs(message, link)
        generated_text = _stream_generate(inputs, image_token_number, model, tokenizer,
                                          temperature, new_max_token, top_p)
        # Text decoding
        response = output_filtering(generated_text, model)
    except Exception:
        # Log the real cause for the Space logs; show the user a generic message.
        traceback.print_exc()
        response = "There may be unsupported format: ex) pdf, video, sound. Only supported is a single image in this version."

    # private log print
    text = message['text']
    files = message['files']
    print('-----------------------------')
    print(f'Link: {link}')
    print(f'Text: {text}')
    print(f'MM Files: {files}')
    print(f'Response: {response}')
    print('-----------------------------\n')

    # Reveal the response progressively (typing effect).
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.012)
        yield buffer


demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=[
        gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"),
        gr.Slider(0, 1, 0.9, label="temperature"),
        gr.Slider(1, 1024, 128, label="new_max_token"),
        gr.Slider(0, 1, 0.95, label="top_p"),
    ],
    additional_inputs_accordion="Generation Hyperparameters",
    theme=gr.themes.Soft(),
    title="TroL",
    description="TroL is efficient 1.8B, 3.8B, and 7B size Large Language and Vision Models built on new propagation strategy. "
                "Its inference speed highly depends on assigning non-scheduled GPU. (Therefore, once all GPUs are busy, then inference may be taken in infinity) "
                "Note that, we don't support history-based conversation referring to previous dialogue",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch()