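# --- Hugging Face page header captured with this file (not part of the program) ---
# CogVLM2-4-Doc / app.py
# DoctorSlimm's picture
# Update app.py
# e4974a1 verified
# raw
# history blame
# No virus
# 2.66 kB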
import os
import torch
import spaces
import gradio as gr
from PIL import Image
from transformers.utils import move_cache
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
# Module-level setup: choose the checkpoint, download it, pick device/dtype,
# then load the tokenizer and the model once at import time.

# Hub repository id of the checkpoint to serve.
# https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B
MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B"
# Alternative int4 checkpoint (swap the assignment above to use it):
# https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B-int4
# MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B-int4"

### DOWNLOAD ###
# NOTE(review): huggingface_hub has already been imported by the time this
# variable is set; if the library reads HF_HUB_ENABLE_HF_TRANSFER at import
# time, this assignment has no effect on the download below — confirm, and
# if so move it above the imports.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
# MODEL_PATH is rebound from the repo id to the value returned by
# snapshot_download (expected to be the local snapshot path), which is what
# the from_pretrained calls below receive.
MODEL_PATH = snapshot_download(MODEL_PATH)
move_cache()

# Use CUDA when available; prefer bfloat16 on GPUs with compute capability
# major version >= 8, otherwise fall back to float16 (also used on CPU).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16

## TOKENIZER ##
# trust_remote_code=True lets code shipped in the model repository run
# locally; only keep this for repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True
)

## MODEL ##
# Loaded in TORCH_TYPE, moved to DEVICE and switched to eval mode.
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
torch_dtype=TORCH_TYPE,
trust_remote_code=True,
).to(DEVICE).eval()
# Prompt wrapper applied to every user prompt before it is handed to the model.
text_only_template = """USER: {} ASSISTANT:"""


@spaces.GPU
def generate_caption(image, prompt):
    """Return the model's reply to ``prompt``, optionally grounded in ``image``.

    Args:
        image: PIL image from the Gradio image input, or ``None`` when no
            image was provided. With ``None`` the request is run as a
            text-only conversation.
        prompt: Instruction text entered by the user.

    Returns:
        The decoded generation, truncated at the first ``<|end_of_text|>``
        marker.
    """
    print(DEVICE)

    # `image` may be None (nothing uploaded). Guard before calling
    # image.convert() so an empty image input no longer raises
    # AttributeError; the 'images' entry below already anticipated this case.
    has_image = image is not None

    query = text_only_template.format(prompt)

    conversation_kwargs = {
        'query': query,
        'history': [],
        'template_version': 'chat',
    }
    if has_image:
        conversation_kwargs['images'] = [image.convert('RGB')]

    input_by_model = model.build_conversation_input_ids(
        tokenizer,
        **conversation_kwargs,
    )

    # Add a leading batch dimension of 1 and move everything to the model's
    # device; the image tensor is additionally cast to the model dtype.
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': (
            [[input_by_model['images'][0].to(DEVICE).to(TORCH_TYPE)]]
            if has_image
            else None
        ),
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        # Hard-coded pad token id; presumably a reserved token of this
        # tokenizer — TODO confirm against the tokenizer config.
        "pad_token_id": 128002,
    }

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Keep only the newly generated tokens (drop the echoed prompt).
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("<|end_of_text|>")[0]

    print("\nCogVLM2:", response)
    return response
## make predictions via api ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app

# Input widgets: an uploaded image (delivered to the callback as a PIL image)
# and a prompt box pre-filled with a default instruction.
image_input = gr.Image(type="pil", label="Upload Image")
prompt_input = gr.Textbox(
    label="Prompt",
    value="Describe the image in great detail",
)

# Output widget: the text returned by generate_caption.
caption_output = gr.Textbox(label="Generated Caption")

demo = gr.Interface(
    fn=generate_caption,
    inputs=[image_input, prompt_input],
    outputs=caption_output,
)

# Launch the interface
demo.launch(share=True)