A working demo.py for your reference

#25
by Colderthanice - opened

As the repository didn't include an installation guide or an executable script, I have developed a demo.py that works in a CUDA environment. Just place it in the git-clone folder and run it.

Note:

  1. It needs transformers==4.35.0.dev0, which is not yet available with pip install as of today. To install 4.35:
    git clone https://github.com/huggingface/transformers.git
    cd transformers
    pip install .

  2. It will download the fuyu-8b weights even when you already have them in the folder. The downloaded weights are stored at:
    ~/.cache/huggingface/hub/models--adept--fuyu-8b

  3. Solved OOM by loading the weights with the float16 (or bfloat16) option.

  4. Solved attention_mask warning.

  5. Choices of image files, choices from a question list, and multi-round questions.

================================================
#Script prepared by novice Green Guo @Beijing with the help of GPT-4.

import os
from transformers import FuyuProcessor, FuyuForCausalLM
from PIL import Image
import torch

def list_files_in_directory(path, extensions=(".png", ".jpeg", ".jpg")):
    """Return the names of image files found directly inside *path*.

    Args:
        path: Directory to scan (non-recursive).
        extensions: File extensions to accept. Matching is case-insensitive,
            so upper-case variants (.JPG, .PNG, ...) need not be listed.

    Returns:
        List of matching file names (names only, not full paths).
    """
    # A tuple default avoids the shared-mutable-default-argument pitfall of
    # the original list default; lower-casing both sides covers every case
    # variant the original spelled out explicitly (.JPG, .PNG, .JPEG).
    suffixes = tuple(ext.lower() for ext in extensions)
    return [
        f
        for f in os.listdir(path)
        if os.path.isfile(os.path.join(path, f)) and f.lower().endswith(suffixes)
    ]

def main():
    """Interactive fuyu-8b demo: pick an image directory, choose images and
    ask (possibly multi-round) questions about each one.

    Side effects: downloads the model weights on first run, reads/writes
    ``last_path.txt`` in the current directory, and loops on ``input()``
    until the process is interrupted.
    """
    # Load model and processor. float16 keeps VRAM usage to ~24 GB; bfloat16
    # also works, with different loading and inference times.
    model_id = "adept/fuyu-8b"
    processor = FuyuProcessor.from_pretrained(model_id)
    model = FuyuForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.float16)

    # Reuse the image-directory path from the previous run unless the user
    # declines. Catch only the expected failures (missing/unreadable file,
    # explicit refusal) instead of a bare `except:` that would also swallow
    # KeyboardInterrupt.
    try:
        with open("last_path.txt", "r") as f:
            last_path = f.read().strip()
        user_input = input(f"Do you want to use the last path '{last_path}'? (yes/no, default yes): ")
        if user_input and user_input.lower() == 'no':
            raise ValueError("User chose to input a new path.")
    except (OSError, ValueError):
        last_path = input("Please provide the image directory path: ")
        with open("last_path.txt", "w") as f:
            f.write(last_path)

    while True:
        # List the first 10 images in the directory.
        images = list_files_in_directory(last_path)[:10]
        for idx, image_name in enumerate(images, start=1):
            print(f"{idx}. {image_name}")

        # Allow the user to select an image by number or by file name.
        image_choice = input(f"Choose an image (1-{len(images)}) or enter its name: ")
        try:
            idx = int(image_choice)
            image_path = os.path.join(last_path, images[idx - 1])
        except (ValueError, IndexError):
            # Not a number, or a number out of range: treat it as a file name.
            image_path = os.path.join(last_path, image_choice)

        try:
            image = Image.open(image_path)
        except (OSError, ValueError):
            print("Cannot open the image. Please check the path and try again.")
            continue

        questions = [
            "Generate a coco-style caption.",
            "What color is the object?",
            "Describe the scene.",
            "Describe the facial expression of the character.",
            "Tell me about the story from the image.",
            "Enter your own question"
        ]

        # Ask the user to select a question from the list, or to type one.
        for idx, q in enumerate(questions, start=1):
            print(f"{idx}. {q}")

        q_choice = int(input("Choose a question or enter your own: "))
        # Only choices 1-5 map to canned questions; anything else (including
        # 0 or negatives, which would silently wrap around with list
        # indexing) falls through to free-form input.
        if 1 <= q_choice <= 5:
            text_prompt = questions[q_choice - 1] + '\n'
        else:
            text_prompt = input("Please enter your question: ") + '\n'

        while True:  # Let the user ask further questions about the same image.
            inputs = processor(text=text_prompt, images=image, return_tensors="pt")
            for k, v in inputs.items():
                inputs[k] = v.to("cuda:0")
            # Explicit attention mask eliminates the generate() warning.
            inputs["attention_mask"] = torch.ones(inputs["input_ids"].shape, device="cuda:0")

            generation_output = model.generate(**inputs, max_new_tokens=50, pad_token_id=model.config.eos_token_id)
            generation_text = processor.batch_decode(generation_output[:, -50:], skip_special_tokens=True)
            print("Answer:", generation_text[0])

            text_prompt = input("Ask another question about the same image or type '/exit' to exit: ") + '\n'
            # strip() removes the '\n' just appended, so '/exit' matches.
            if text_prompt.strip() == '/exit':
                break

# Script entry point. The original `if name == "main"` raised a NameError:
# the dunder underscores were stripped by the forum's markdown rendering.
if __name__ == "__main__":
    main()

Hey, the attention_mask will be updated in the PR https://github.com/huggingface/transformers/pull/27007 which will add batching, you'll be able to cache a few QAs directly!

Sign up or log in to comment