# CogVLM2-4-Doc / app.py
import os
import torch
import spaces
import gradio as gr
from PIL import Image
from transformers.utils import move_cache
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
# Primary model: https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B
MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B"
# Int4-quantized variant for smaller GPUs:
# https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B-int4
# MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B-int4"

### DOWNLOAD ###
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'  # faster downloads; requires the hf_transfer package
MODEL_PATH = snapshot_download(MODEL_PATH)  # resolve the repo id to a local snapshot path
move_cache()  # migrate any legacy transformers cache to the current layout

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# bfloat16 requires compute capability >= 8.0 (Ampere or newer); fall back to float16 otherwise
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
## TOKENIZER ##
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True
)

## MODEL ##
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
).to(DEVICE).eval()

# Prompt template used to wrap the user's query
text_only_template = """USER: {} ASSISTANT:"""
@spaces.GPU  # request a GPU for the duration of this call (ZeroGPU Spaces)
def generate_caption(image, prompt):
    print(DEVICE)
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Process the image and the prompt
    image = image.convert('RGB')
    query = text_only_template.format(prompt)
    input_by_model = model.build_conversation_input_ids(
        tokenizer,
        query=query,
        history=[],
        images=[image],
        template_version='chat'
    )
    # Batch dimension of 1; move everything onto the model's device
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[input_by_model['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": 128002,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Drop the prompt tokens; decode only the newly generated ones
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("<|end_of_text|>")[0]
    print("\nCogVLM2:", response)
    return response
## make predictions via the API ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
# (a commented gradio_client example is at the end of this file)
demo = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Prompt", value="Describe the image in great detail"),
    ],
    outputs=gr.Textbox(label="Generated Caption"),
)
# Launch the interface (share=True has no effect when running on Hugging Face Spaces)
demo.launch(share=True)
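
# A minimal client-side sketch of calling this Space over the API with gradio_client,
# per the guide linked above. The Space id "DoctorSlimm/CogVLM2-4-Doc", the default
# Interface endpoint "/predict", and the file "page.png" are assumptions for illustration.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("DoctorSlimm/CogVLM2-4-Doc")  # hypothetical Space id
#   caption = client.predict(
#       handle_file("page.png"),               # image input (hypothetical file)
#       "Describe the image in great detail",  # prompt input
#       api_name="/predict",
#   )
#   print(caption)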