# GLM-4-DOC / app.py — Hugging Face Space by vilarin
# Last commit: 3f5fb82 ("Update app.py"), file size 4.16 kB
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
# Candidate checkpoints; the first entry is the default when MODEL_ID is unset.
MODEL_LIST = ["THUDM/glm-4v-9b"]
# NOTE(review): HF_TOKEN is read here but not used by the code visible in this file.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Fall back to the default checkpoint so an unset (or empty) MODEL_ID does not
# crash the next line with AttributeError on None.
MODEL_ID = os.environ.get("MODEL_ID") or MODEL_LIST[0]
# Repository name without the owner prefix, used for display only.
MODEL_NAME = MODEL_ID.split("/")[-1]
TITLE = "<h1><center>VL-Chatbox</center></h1>"
DESCRIPTION = f'<h3><center>MODEL: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></center></h3>'
# Styling for the "Duplicate Space" button (matched via elem_classes="duplicate-button").
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Options shared by the model load below: bf16 weights, reduced peak CPU memory
# while loading, and the checkpoint's own modelling code from the Hub.
_MODEL_LOAD_OPTIONS = {
    "torch_dtype": torch.bfloat16,
    "low_cpu_mem_usage": True,
    "trust_remote_code": True,
}
# Load once at import time and move the weights to device index 0 (first GPU).
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_MODEL_LOAD_OPTIONS).to(0)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Put the model in evaluation (inference) mode.
model.eval()
def _load_rgb_image(path):
    """Open the image at *path* and return an RGB copy, closing the file handle."""
    with Image.open(path) as img:
        # convert() returns a new, fully loaded image that stays usable after the
        # file is closed; RGB also normalises RGBA / palette PNG uploads.
        return img.convert("RGB")


def _is_file_entry(user_turn):
    """Return True when a history user turn is an uploaded file rather than text.

    NOTE(review): assumes Gradio's tuple-style chat history, where an uploaded
    file is stored as a ``(path,)`` / ``(path, alt_text)`` tuple (a list after a
    JSON round-trip) — confirm against the installed Gradio version.
    """
    return isinstance(user_turn, (tuple, list))


def _latest_history_image_path(history):
    """Return the path of the most recently uploaded file in *history*, or None."""
    for user_turn, _answer in reversed(history):
        if _is_file_entry(user_turn) and user_turn:
            return user_turn[0]
    return None


def _build_conversation(message, history):
    """Translate a multimodal textbox message plus chat history into chat-template turns.

    A message with files starts a new conversation around its last file. Otherwise
    the prior text turns are replayed and the most recently uploaded image from
    history is attached to the new user turn.

    Raises:
        gr.Error: if neither the message nor the history contains an image.
    """
    text = message["text"]
    if message["files"]:
        image = _load_rgb_image(message["files"][-1])
        return [{"role": "user", "image": image, "content": text}]

    image_path = _latest_history_image_path(history)
    if image_path is None:
        raise gr.Error("Please upload an image first.")
    image = _load_rgb_image(image_path)

    conversation = []
    for prompt, answer in history:
        if _is_file_entry(prompt):
            # Upload turns carry no text; the image is attached to the final user turn.
            continue
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": "" if answer is None else answer})
    conversation.append({"role": "user", "image": image, "content": text})
    return conversation


def _generation_kwargs(temperature, max_length, streamer):
    """Build model.generate() keyword arguments from the UI settings."""
    kwargs = dict(max_length=max_length, streamer=streamer, repetition_penalty=1.2)
    if temperature > 0:
        # NOTE(review): top_k=1 keeps only the most likely token, so temperature
        # does not change which token is picked; kept to match the original setup.
        kwargs.update(do_sample=True, top_k=1, temperature=temperature)
    else:
        # transformers rejects temperature <= 0 when sampling. Greedy decoding is the
        # temperature -> 0 limit and picks the same tokens as top_k=1 sampling.
        kwargs["do_sample"] = False
    return kwargs


@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_length: int):
    """Stream the model's reply to a multimodal chat message.

    Args:
        message: value of the MultimodalTextbox, a dict with "text" and "files" keys.
        history: prior turns as (user, assistant) pairs.
        temperature: sampling temperature from the UI; <= 0 selects greedy decoding.
        max_length: passed to generate() as ``max_length``.

    Yields:
        The response text accumulated so far, growing with each streamed chunk.

    Raises:
        gr.Error: if no image is available, or if generation fails.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')
    conversation = _build_conversation(message, history)
    print(f"Conversation is -\n{conversation}")

    # return_dict=True yields a mapping (input ids, attention mask, ...), not bare ids.
    inputs = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {**inputs, **_generation_kwargs(temperature, max_length, streamer)}

    errors = []

    def _generate():
        # Grad mode is thread-local; generate() disables gradients on its own thread.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # reported to the UI after the stream ends
            errors.append(exc)
            # Close the stream so the consumer loop stops instead of waiting for the timeout.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}")
# Chat components created outside the Blocks context and handed to
# gr.ChatInterface (chatbot= / textbox=) in the layout below.
chatbot = gr.Chatbot(height=450)
_chat_input_options = dict(
    show_label=False,
    placeholder="Enter message or upload file...",
    file_types=["image"],
    interactive=True,
)
chat_input = gr.MultimodalTextbox(**_chat_input_options)
# Example inputs for the multimodal textbox. Each row holds one MultimodalTextbox
# value ({"text": ..., "files": [...]}); paths are relative to the working directory.
_EXAMPLE_PROMPT = "Describe it in great detail."
_EXAMPLE_FILES = ("laptop.jpg", "hotel.jpg", "spacecat.png")
EXAMPLES = [
    [{"text": _EXAMPLE_PROMPT, "files": [f"./{filename}"]}]
    for filename in _EXAMPLE_FILES
]
# Page layout: header HTML, a "duplicate this Space" button, the chat UI with
# its generation-parameter controls, and clickable examples for the input box.
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Built with render=False and handed to gr.ChatInterface, which places them.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False,
    )
    max_length_slider = gr.Slider(
        minimum=128, maximum=4096, step=1, value=1024, label="Max Length", render=False,
    )

    # additional_inputs are passed to stream_chat after (message, history), in order.
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_length_slider],
    )
    gr.Examples(EXAMPLES, [chat_input])
if __name__ == "__main__":
    # Enable request queueing with the API closed and hidden; no public share link.
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(share=False, show_api=False)