# A100 Zero GPU
import spaces
# flash attention
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
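# NOTE (assumption): FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is commonly used when installing
# flash-attn on Spaces where no GPU is visible at build time, so pip does not try to
# compile the CUDA extension during installation and relies on the prebuilt kernels instead.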
# Phantom Package
import torch
from PIL import Image
from utils.utils import *
from model.load_model import load_model
# Gradio Package
import time
import gradio as gr
from threading import Thread
from accelerate import Accelerator
from transformers import TextIteratorStreamer
from torchvision.transforms.functional import pil_to_tensor
# accelerate device handler (used only for accel.device placement below)
accel = Accelerator()
# loading models (1.8B / 3.8B / 7B checkpoints with their tokenizers)
model_1_8, tokenizer_1_8 = load_model(size='1.8b')
model_3_8, tokenizer_3_8 = load_model(size='3.8b')
model_7, tokenizer_7 = load_model(size='7b')
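# All three checkpoints are loaded once at startup; bot_streaming() below casts the selected
# one to bfloat16 and moves it onto accel.device (the GPU attached for the request on ZeroGPU).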
def threading_function(inputs, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
    # propagation: preprocess the demo inputs into generation-ready kwargs
    _inputs = model.eval_process(inputs=inputs,
                                 data='demo',
                                 tokenizer=tokenizer,
                                 device=device)
    generation_kwargs = _inputs
    generation_kwargs.update({'streamer': streamer,
                              'do_sample': True,
                              'max_new_tokens': new_max_token,
                              'top_p': top_p,
                              'temperature': temperature,
                              'use_cache': True})
    return model.generate(**generation_kwargs)
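# threading_function runs model.generate() on a background thread; the TextIteratorStreamer
# passed to it is iterated in bot_streaming() so partial tokens can be yielded to the Gradio
# UI while generation is still running. With multimodal=True, Gradio passes `message` as a
# dict like {"text": "...", "files": ["/path/to/image.png"]}, which is unpacked below.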
@spaces.GPU
def bot_streaming(message, history, link, temperature, new_max_token, top_p):
    # model selection by the radio-button label
    if "1.8B" in link:
        model = model_1_8
        tokenizer = tokenizer_1_8
    elif "3.8B" in link:
        model = model_3_8
        tokenizer = tokenizer_3_8
    elif "7B" in link:
        model = model_7
        tokenizer = tokenizer_7
    else:
        # guard against unexpected radio values
        raise ValueError(f"Unknown model size: {link}")

    # float32/float16 -> bfloat16 conversion
    for param in model.parameters():
        if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
            param.data = param.data.to(torch.bfloat16)

    # cpu -> gpu
    for param in model.parameters():
        if not param.is_cuda:
            param.data = param.to(accel.device)
    try:
        # prompt type -> input prompt
        if len(message['files']) == 1:
            # image load
            image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
            inputs = [{'image': image.to(accel.device), 'question': message['text']}]
        elif len(message['files']) > 1:
            raise Exception("Only a single image is supported in this version.")
        else:
            inputs = [{'question': message['text']}]

        # text generation
        with torch.inference_mode():
            # streamer receives tokens as the background thread generates them
            streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

            # run generation on a background thread so the streamer can be consumed here
            thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
                                                                   streamer=streamer,
                                                                   model=model,
                                                                   tokenizer=tokenizer,
                                                                   device=accel.device,
                                                                   temperature=temperature,
                                                                   new_max_token=new_max_token,
                                                                   top_p=top_p))
            thread.start()

            # accumulate streamed text
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text

            # text decoding / cleanup
            response = output_filtering(generated_text, model)
    except Exception:
        response = "There may be an unsupported format (e.g., pdf, video, audio). Only a single image is supported in this version."
    # private log print
    text = message['text']
    files = message['files']
    print('-----------------------------')
    print(f'Link: {link}')
    print(f'Text: {text}')
    print(f'MM Files: {files}')
    print(f'Response: {response}')
    print('-----------------------------\n')

    # stream the final response back to the chat UI character by character
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.012)
        yield buffer
demo = gr.ChatInterface(fn=bot_streaming,
                        additional_inputs=[gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"),
                                           gr.Slider(0, 1, 0.9, label="temperature"),
                                           gr.Slider(1, 1024, 128, label="new_max_token"),
                                           gr.Slider(0, 1, 0.95, label="top_p")],
                        additional_inputs_accordion="Generation Hyperparameters",
                        theme=gr.themes.Soft(),
                        title="Phantom",
                        description="Phantom is a family of super-efficient 0.5B, 1.8B, 3.8B, and 7B Large Language and Vision Models built on a new propagation strategy. "
                                    "Inference speed depends heavily on being assigned a non-scheduled GPU, so while all GPUs are busy a request may wait indefinitely. "
                                    "Note that history-based conversation referring to previous dialogue is not supported.",
                        stop_btn="Stop Generation", multimodal=True)
demo.launch()