Cristiants's picture
Update app.py
8f128da
raw
history blame contribute delete
No virus
1.33 kB
import io

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
# Load the BLIP-2 (OPT-2.7B) processor and captioning model once at import time.
MODEL_ID = "Salesforce/blip2-opt-2.7b"
processor = AutoProcessor.from_pretrained(MODEL_ID)
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 halves GPU memory, but many CPU kernels lack half-precision support
# (or are very slow in it), so fall back to float32 when no GPU is available.
dtype = torch.float16 if device == "cuda" else torch.float32
model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype=dtype)
model.to(device)
# def predict(inp):
# inp = transforms.ToTensor()(inp).unsqueeze(0)
# with torch.no_grad():
# prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
# confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
# return confidences
# demo = gr.Interface(fn=predict,
# inputs=gr.inputs.Image(type="pil"),
# outputs=gr.outputs.Label(num_top_classes=3)
# )
def predict(imageurl):
    """Download the image at *imageurl* and return a BLIP-2 caption for it.

    Args:
        imageurl: URL of the image to caption; surrounding whitespace is ignored.

    Returns:
        The generated caption, prefixed with ``"caption: "``.

    Raises:
        requests.RequestException: if the URL is invalid, the download fails or
            times out, or the server answers with an HTTP error status.
        PIL.UnidentifiedImageError: if the response body is not a readable image.
    """
    # The timeout stops a stalled server from hanging the worker forever; the
    # context manager releases the connection once the body has been read.
    # `.content` (unlike `.raw`) honours Content-Encoding such as gzip.
    with requests.get(imageurl.strip(), timeout=30) as response:
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    # The processor returns CPU float32 tensors. Move them to the model's device
    # and cast the floating-point ones to the model's dtype, otherwise generate()
    # fails with a device/dtype mismatch when the model runs on GPU in float16.
    inputs = processor(images=image, return_tensors="pt").to(model.device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=20)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return "caption: " + generated_text
# Web UI: a text box for the image URL in, the caption string out.
# `gr.outputs.*` was deprecated in Gradio 3 and removed in Gradio 4; the
# top-level component class works in both. `num_top_classes` is dropped because
# it only affects dict (class -> confidence) outputs, and predict returns a str.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs=gr.Label(),
)
demo.launch()