GLM-4-DOC / app.py
vilarin's picture
Update app.py
ac56402 verified
raw
history blame
No virus
5.09 kB
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
# Models this Space is set up to serve; the first entry is the fallback used
# when the MODEL_ID environment variable is missing or empty.
MODEL_LIST = ["THUDM/glm-4v-9b"]
# Hugging Face access token from the environment, if one is configured.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Hub repo id to load. Falling back to MODEL_LIST[0] keeps the .split() below
# from failing with AttributeError on None when MODEL_ID is not set.
MODEL_ID = os.environ.get("MODEL_ID") or MODEL_LIST[0]
# Display name: the repo id with its owner prefix removed.
MODEL_NAME = MODEL_ID.split("/")[-1]
TITLE = "<h1>VL-Chatbox</h1>"
# Header shown under the title; tags are closed in the order they were opened.
DESCRIPTION = f'<center><p>A SPACE FOR VLM MODELS</p><br><h3>MODEL NOW: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></h3></center>'
# Page styles: a centered, pill-shaped black duplicate button and a centered <h1>.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h1 {
text-align: center;
display: block;
}
"""
# Load the checkpoint named by MODEL_ID in bfloat16 (trusting the repo's custom
# modelling code), move it to device index 0 (presumably the first GPU), and
# switch it to inference mode; nn.Module.eval() returns the module itself.
model = (
    AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    .to(0)
    .eval()
)
# Matching tokenizer, which also relies on the repo's custom code.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
def _open_rgb(path):
    """Load the image file at *path* and return it as an RGB image.

    The file is opened in a ``with`` block so its handle is released at once;
    ``convert()`` reads the pixel data and returns a new image that no longer
    depends on the open file.
    """
    with Image.open(path) as img:
        return img.convert('RGB')


def _build_conversation(message, history):
    """Build the turn list passed to ``tokenizer.apply_chat_template``.

    Args:
        message: Gradio multimodal payload with ``"text"`` and ``"files"`` keys.
        history: Prior turns as ``(user, assistant)`` pairs. A pair whose
            assistant entry is ``None`` is treated as an image upload whose
            file path is ``user[0]`` (presumably Gradio's tuple-format
            history; re-check if the Chatbot format changes).

    Returns:
        A list of ``{"role", "content"}`` dicts ending with the new user turn.
        At most one image is attached, always on that final user turn:

        * If the message carries files, only the last file is used and the
          history is not included at all.
        * Otherwise every image-upload turn in the history becomes an empty
          user/assistant pair (keeping roles alternating), and the last image
          found in the history is attached to the new user turn.
    """
    text = message['text']
    if message["files"]:
        image = _open_rgb(message["files"][-1])
        return [{"role": "user", "image": image, "content": text}]

    conversation = []
    # Initialised up front so a text-only history (no upload turns) still
    # produces a valid final turn instead of referencing an unbound name.
    image = None
    for prompt, answer in history:
        if answer is None:
            image = _open_rgb(prompt[0])
            conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
        else:
            conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])

    new_turn = {"role": "user", "content": text}
    if image is not None:
        new_turn["image"] = image
    conversation.append(new_turn)
    return conversation


@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
    """Stream the model's reply to *message*, given the chat *history*.

    ``model.generate`` runs on a background thread. This generator yields the
    reply text accumulated so far each time the streamer produces a chunk.

    Args:
        message: Gradio multimodal payload (``"text"`` and ``"files"`` keys).
        history: Prior ``(user, assistant)`` turns from the Chatbot.
        temperature: Sampling temperature. A value of ``0`` or below selects
            greedy decoding instead of sampling.
        max_length: Forwarded to ``generate`` as ``max_length``.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        penalty: Repetition penalty forwarded to ``generate``.

    Yields:
        The cumulative generated text.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')
    conversation = _build_conversation(message, history)
    print(f"Conversation is -\n{conversation}")

    inputs = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    # skip_prompt drops the echoed prompt. If no new text arrives within 60 s,
    # iterating the streamer raises instead of blocking forever.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    # UI values are coerced explicitly: transformers validates top_k as int and
    # temperature / repetition_penalty as float.
    generate_kwargs = dict(
        max_length=int(max_length),
        streamer=streamer,
        repetition_penalty=float(penalty),
        # Presumably GLM-4's stop tokens (<|endoftext|>, <|user|>,
        # <|observation|>); confirm against the tokenizer's vocabulary.
        eos_token_id=[151329, 151336, 151338],
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            top_p=float(top_p),
            top_k=int(top_k),
            temperature=float(temperature),
        )
    else:
        # Sampling with temperature 0 is rejected by transformers (the error
        # would surface only as a streamer timeout); decode greedily instead.
        generate_kwargs["do_sample"] = False
    # Unpack the tokenizer's tensors as keyword arguments for generate().
    gen_kwargs = {**inputs, **generate_kwargs}

    # No torch.no_grad() here: grad mode is thread-local, so a context manager
    # in this thread would not affect the worker thread. generate() disables
    # gradients itself.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
# Chat transcript widget shared with the ChatInterface below.
chatbot = gr.Chatbot(height=450)
# Multimodal input box accepting text plus image files; stream_chat uses only
# the most recently attached file of a message.
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter a message or upload one image...",
    show_label=False,
)
# Example prompts, each pairing a question with an image shipped alongside the
# app (paths are relative to the working directory).
EXAMPLES = [
    [{"text": "Describe it in detail", "files": ["./laptop.jpg"]}],
    [{"text": "Where is it?", "files": ["./hotel.jpg"]}],
    [{"text": "Is it real?", "files": ["./spacecat.png"]}],
]
# Page layout: title, description, duplicate button, the streaming chat
# interface, and clickable examples that fill the input box.
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    # The sliders are passed to stream_chat positionally after
    # (message, history): temperature, max_length, top_p, top_k, penalty.
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=8192,
                step=1,
                value=1024,
                label="Max Length",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=10,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                # transformers requires a strictly positive repetition penalty,
                # so 0.0 must not be selectable; 1.0 (no penalty) stays reachable.
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
    )
    gr.Examples(EXAMPLES, [chat_input])
if __name__ == "__main__":
    # Turn on the request queue with its REST API closed, then serve the app
    # without the API page and without a public share link.
    demo.queue(api_open=False)
    demo.launch(show_api=False, share=False)