Cristiants's picture
Update app.py
8f128da
raw
history blame
No virus
1.33 kB
import requests
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
import gradio as gr
# Pick the device first so the weight dtype can follow it: float16 halves GPU
# memory, but PyTorch's CPU float16 support is limited (some kernels, e.g.
# matmul, are not implemented for Half on older releases and the rest are slow),
# so fall back to float32 when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=dtype
)
model.to(device)
# def predict(inp):
# inp = transforms.ToTensor()(inp).unsqueeze(0)
# with torch.no_grad():
# prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
# confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
# return confidences
# demo = gr.Interface(fn=predict,
# inputs=gr.inputs.Image(type="pil"),
# outputs=gr.outputs.Label(num_top_classes=3)
# )
def predict(imageurl):
    """Download the image at *imageurl* and return a BLIP-2 caption for it.

    Args:
        imageurl: HTTP(S) URL of the image, as typed into the Gradio text box.

    Returns:
        The generated caption, prefixed with ``'caption: '``.

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status.
        PIL.UnidentifiedImageError: if the response body is not a readable image.
    """
    # The timeout keeps a stalled server from hanging the worker indefinitely.
    # raise_for_status() surfaces 4xx/5xx responses instead of handing an HTML
    # error page to PIL. The context manager releases the connection afterwards.
    with requests.get(imageurl, stream=True, timeout=30) as response:
        response.raise_for_status()
        # response.raw is the undecoded socket stream; have urllib3 undo any
        # Content-Encoding (e.g. gzip) before PIL parses the bytes.
        response.raw.decode_content = True
        # convert() forces the pixel data to be read while the stream is still open.
        image = Image.open(response.raw).convert('RGB')
    # The processor returns CPU float32 tensors. Move them to the model's device
    # and cast the floating-point ones to the model's dtype; otherwise generate()
    # fails with a device/dtype mismatch when the model lives on GPU in float16.
    inputs = processor(images=image, return_tensors="pt").to(model.device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=20)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return 'caption: ' + generated_text
# Web UI: a text box for the image URL in, the caption string out.
# gr.Label is the top-level component that replaced the deprecated
# gr.outputs.Label namespace (removed in Gradio 4); it also exists in Gradio 3.x.
# NOTE(review): num_top_classes only affects dict outputs; predict returns a
# plain string, so it has no visible effect here.
demo = gr.Interface(fn=predict,
                    inputs="text",
                    outputs=gr.Label(num_top_classes=3),
                    )
demo.launch()