chatBot-kosmos2 / kosmos.py
the-future-dev's picture
first commit
b5dc2f8
raw
history blame contribute delete
No virus
1.08 kB
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
from os import path
from torchvision.transforms import ToTensor
# Hugging Face Hub checkpoint used for every request in this module.
model_id = "microsoft/kosmos-2-patch14-224"

# Both objects are loaded once, at import time, and shared by every call below:
# the processor turns a prompt + image into model inputs, the model generates.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id)
def _open_if_path(image):
    """Open *image* as an RGB PIL image when it is a path to an existing file.

    Any other value (e.g. an already-loaded PIL image) is returned unchanged,
    so the processor receives it exactly as before.
    """
    import os  # module level only imports ``os.path`` (as ``path``)

    if isinstance(image, (str, os.PathLike)) and os.path.isfile(image):
        # ``convert`` returns a fully loaded copy, so the file handle can be
        # closed as soon as the ``with`` block exits.
        with Image.open(image) as img:
            return img.convert("RGB")
    return image


def single_image_classification(image, prompt="", max_new_tokens=30):
    """Run Kosmos-2 generation on a single image and return the processed text.

    Args:
        image: An image accepted by the module-level ``processor`` (e.g. a PIL
            image), or a ``str``/``os.PathLike`` path to an existing image
            file, which is opened with PIL and converted to RGB first.
        prompt: Text prompt passed to the processor together with the image.
        max_new_tokens: Maximum number of new tokens ``model.generate`` may
            produce.

    Returns:
        The first decoded sequence after
        ``processor.post_process_generation(..., cleanup_and_extract=False)``.
        Per the Transformers Kosmos-2 docs, that flag keeps the grounding
        markup in the string instead of returning a (caption, entities) pair.

    Side effects:
        Prints the raw decoded text and the processed text to stdout.
    """
    image = _open_if_path(image)

    # Put the input tensors on the model's device, so this still works if the
    # model was moved (e.g. to a GPU) after being loaded.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_embeds_position_mask=inputs["image_embeds_position_mask"],
        use_cache=True,
        max_new_tokens=max_new_tokens,
    )

    # A single prompt/image pair is processed, so keep the only decoded sequence.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print("GENERATED:", generated_text)

    processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)
    print("PROCESSED:", processed_text)
    return processed_text