# Sample CLI client for the Moondream vision-language model.
"""Interactive CLI for the Moondream vision-language model.

Usage:
    python sample.py --image path/to/img.jpg [--prompt "..."] [--caption] [--cpu]

With ``--caption`` the script prints a single generated caption. With
``--prompt`` it answers that one question. With neither, it drops into an
interactive REPL that streams answers token-by-token and keeps chat history.
"""

import argparse
from queue import Queue
from threading import Thread

import torch
from PIL import Image
from transformers import AutoTokenizer, TextIteratorStreamer

from moondream.hf import LATEST_REVISION, Moondream, detect_device

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--prompt", type=str, required=False)
    parser.add_argument("--caption", action="store_true")
    parser.add_argument("--cpu", action="store_true")
    args = parser.parse_args()

    # Device selection: --cpu forces CPU/float32; otherwise let the helper
    # pick the best available accelerator and matching dtype.
    if args.cpu:
        device = torch.device("cpu")
        dtype = torch.float32
    else:
        device, dtype = detect_device()
        if device != torch.device("cpu"):
            print("Using device:", device)
            print("If you run into issues, pass the `--cpu` flag to this script.")
            print()

    image_path = args.image
    prompt = args.prompt

    model_id = "vikhyatk/moondream2"
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
    moondream = Moondream.from_pretrained(
        model_id,
        revision=LATEST_REVISION,
        torch_dtype=dtype,
    ).to(device=device)
    moondream.eval()  # inference only — disable dropout etc.

    image = Image.open(image_path)

    if args.caption:
        print(moondream.caption(images=[image], tokenizer=tokenizer)[0])
    else:
        # Encode the image once; the embedding is reused for every question.
        image_embeds = moondream.encode_image(image)

        if prompt is None:
            # Interactive mode: stream each answer while accumulating the
            # conversation so follow-up questions have context.
            chat_history = ""
            while True:
                question = input("> ")

                result_queue = Queue()
                streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
                # Generation runs on a worker thread so the main thread can
                # consume the streamer as tokens arrive; the full answer is
                # handed back through result_queue.
                thread = Thread(
                    target=moondream.answer_question,
                    args=(image_embeds, question, tokenizer, chat_history),
                    kwargs={"streamer": streamer, "result_queue": result_queue},
                )
                thread.start()

                # Buffer fragments that might be the start of an end-of-text
                # marker ("<" or "END") so partial markers are not printed.
                buffer = ""
                for new_text in streamer:
                    buffer += new_text
                    if not new_text.endswith(("<", "END")):
                        print(buffer, end="", flush=True)
                        buffer = ""
                print(buffer)
                thread.join()

                answer = result_queue.get()
                chat_history += f"Question: {question}\n\nAnswer: {answer}\n\n"
        else:
            # One-shot mode: answer the provided prompt and exit.
            print(">", prompt)
            answer = moondream.answer_question(image_embeds, prompt, tokenizer)
            print(answer)