No support for pipeline or TextStreamer

#34
by mohamedlotfy50 - opened

I tried to run the model with pipeline and found that it does not support the image-to-text task. I also tried to use TextStreamer, but it raised an exception.

mohamedlotfy50 changed discussion title from "No support for pipeline or stream" to "No support for pipeline or TextStreamer"

TextStreamer works fine for me.
Here is the code I'm using:

from transformers import TextStreamer
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import requests
from io import BytesIO
import argparse


# Image placeholder text that main() prepends to the first user message.
DEFAULT_IMAGE_TOKEN = "<|image_1|>"

def load_image(image_file, timeout=30):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Seconds to wait for the remote server before giving up.
            Only used when ``image_file`` is a URL.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    if image_file.startswith(("http://", "https://")):
        # Without a timeout, requests may block forever on a stalled server.
        response = requests.get(image_file, timeout=timeout)
        # Fail on HTTP errors here instead of handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file

    # convert() returns a fully loaded copy, so the underlying file handle
    # can be released as soon as the conversion is done.
    with Image.open(source) as raw_image:
        return raw_image.convert("RGB")

def main(args):
    """Run an interactive, streaming chat about a single image.

    Loads the model and processor named by ``args.model_base`` (plus the
    adapter at ``args.model_path`` when given), loads ``args.image_file``,
    then reads user turns from stdin until an empty line or EOF. Each reply
    is streamed to stdout while it is generated and appended to the chat
    history so later turns see the whole conversation.

    Args:
        args: Parsed command-line namespace providing ``model_base``,
            ``model_path``, ``image_file``, ``device``, ``temperature``,
            ``repetition_penalty``, ``max_new_tokens`` and ``debug``.
    """
    model_id = args.model_base
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=args.device,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )

    if args.model_path:
        model.load_adapter(args.model_path)

    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    tokenizer = processor.tokenizer

    messages = [
        {"role": "system", "content": "You are a helpful AI assistant. Provide a useful information about the given image."},
    ]

    image = load_image(args.image_file)

    do_sample = args.temperature > 0
    generation_args = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": args.repetition_penalty,
    }
    if do_sample:
        # Temperature only matters when sampling; passing 0 together with
        # greedy decoding just makes transformers complain about an unused
        # (and invalid) sampling parameter.
        generation_args["temperature"] = args.temperature

    while True:
        try:
            inp = input("User: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        # Flush so the label is visible before the (possibly slow) generation
        # starts streaming tokens.
        print("Assistant: ", end="", flush=True)

        is_first_user_turn = all(m["role"] != "user" for m in messages)
        if image is not None and is_first_user_turn:
            # The image placeholder goes only into the first user turn; this
            # holds whether or not the system message above is kept.
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp

        messages.append({"role": "user", "content": inp})

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(prompt, image, return_tensors="pt").to(args.device)

        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            generate_ids = model.generate(
                **inputs,
                eos_token_id=tokenizer.eos_token_id,
                streamer=streamer,
                **generation_args,
                use_cache=True,
            )

        # generate() returns prompt + completion. Keep only the newly
        # generated tokens so the stored assistant turn does not contain the
        # whole prompt (which would otherwise be re-fed and grow every turn).
        prompt_len = inputs["input_ids"].shape[1]
        reply_ids = generate_ids[:, prompt_len:]

        outputs = processor.batch_decode(reply_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        messages.append({"role": "assistant", "content": outputs})

        if args.debug:
            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")


if __name__ == "__main__":
    # Command-line interface, declared as (flag, add_argument keyword args)
    # pairs and registered in this order.
    cli_options = (
        ("--model-path", {"type": str, "default": None}),
        ("--model-base", {"type": str, "default": "microsoft/Phi-3-vision-128k-instruct"}),
        ("--image-file", {"type": str, "required": True}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--temperature", {"type": float, "default": 0}),
        ("--repetition-penalty", {"type": float, "default": 1.0}),
        ("--max-new-tokens", {"type": int, "default": 500}),
        ("--debug", {"action": "store_true"}),
    )

    parser = argparse.ArgumentParser()
    for flag, option_kwargs in cli_options:
        parser.add_argument(flag, **option_kwargs)

    args = parser.parse_args()
    main(args)

Thanks a lot! I think my problem was that I forgot to set skip_special_tokens=True.

Sign up or log in to comment