from PIL import Image
import requests
import gradio as gr
from gradio.mix import Parallel
from transformers import (
    ViTFeatureExtractor,
    ImageClassificationPipeline,
    PerceiverForImageClassificationConvProcessing,
    PerceiverFeatureExtractor,
    VisionEncoderDecoderModel,
    AutoTokenizer,
)
import json
import os
# Imported from the local file spaces_info.py; note that `description` and
# `examples` are shadowed by the local definitions further below.
from spaces_info import description, examples, initial_prompt_value

# Constants: the BLOOM Inference API endpoint and the HF API token.
# API_URL falls back to the public BLOOM endpoint if the env var is unset.
API_URL = os.getenv("API_URL", "https://api-inference.huggingface.co/models/bigscience/bloom")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
print(API_URL)  # do not print HF_API_TOKEN: it is a secret
def query(payload):
    print(payload)
    response = requests.post(API_URL, json=payload, headers=headers)
    print(response)
    return json.loads(response.content.decode("utf-8"))
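
# Illustrative example (not part of the original file): query() takes the same
# payload shape that inference() builds below.
#   query({"inputs": "Once upon a time", "parameters": {"max_new_tokens": 32}})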
def inference(input_sentence, max_length, sample_or_greedy, seed=42):
    if sample_or_greedy == "Sample":
        parameters = {
            "max_new_tokens": max_length,
            "top_p": 0.9,
            "do_sample": True,
            "seed": seed,
            "early_stopping": False,
            "length_penalty": 0.0,
            "eos_token_id": None,
        }
    else:
        parameters = {
            "max_new_tokens": max_length,
            "do_sample": False,
            "seed": seed,
            "early_stopping": False,
            "length_penalty": 0.0,
            "eos_token_id": None,
        }
    payload = {
        "inputs": input_sentence,
        "parameters": parameters,
        "options": {"use_cache": False},
    }
    data = query(payload)
    if "error" in data:
        # Return a single string, matching the single Textbox output below.
        return f"<span style='color:red'>ERROR: {data['error']}</span>"
    # The API echoes the prompt, so keep only the newly generated continuation.
    generation = data[0]["generated_text"].split(input_sentence, 1)[1]
    print(generation)
    return input_sentence + generation
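
# Illustrative usage (mirrors the call made in self_caption() below):
#   story = inference("a cat is sleeping on a sofa", 64, "Sample", seed=42)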
def self_caption(image):
    # Caption the image with a ViT encoder + GPT-2 decoder checkpoint.
    repo_name = "ydshieh/vit-gpt2-coco-en"
    feature_extractor2 = ViTFeatureExtractor.from_pretrained(repo_name)
    tokenizer = AutoTokenizer.from_pretrained(repo_name)
    model2 = VisionEncoderDecoderModel.from_pretrained(repo_name)
    pixel_values = feature_extractor2(image, return_tensors="pt").pixel_values
    # Autoregressively generate the caption with beam search.
    generated_ids = model2.generate(
        pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
    )
    # Decode the generated token ids into text.
    preds = tokenizer.batch_decode(generated_ids.sequences, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    print("Predictions:", preds)
    preds = " ".join(preds)
    # Use the caption as the BLOOM prompt:
    # inference(input_sentence, max_length, sample_or_greedy, seed)
    story = inference(preds, 64, "Sample", 42)
    return story
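
# Performance note (an assumption, not in the original code): self_caption()
# reloads the captioning models on every request. A minimal sketch of loading
# them once at module scope instead, reusing the same checkpoint:
#
#   CAPTION_REPO = "ydshieh/vit-gpt2-coco-en"
#   caption_extractor = ViTFeatureExtractor.from_pretrained(CAPTION_REPO)
#   caption_tokenizer = AutoTokenizer.from_pretrained(CAPTION_REPO)
#   caption_model = VisionEncoderDecoderModel.from_pretrained(CAPTION_REPO)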
def classify_image(image):
    # Alternative checkpoints: google/vit-base-patch16-224, deepmind/vision-perceiver-conv
    feature_extractor = PerceiverFeatureExtractor.from_pretrained("deepmind/vision-perceiver-conv")
    model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv")
    image_pipe = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
    results = image_pipe(image)
    print("RESULTS:", results)
    # Convert to the {label: score} format Gradio expects.
    output = {}
    for prediction in results:
        output[prediction["label"]] = prediction["score"]
    print("OUTPUT:", output)
    return output
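
# The pipeline returns a list of {"label": str, "score": float} dicts; the loop
# above reshapes it into the mapping gr.outputs.Label expects, e.g.
# (illustrative values):
#   {"tabby, tabby cat": 0.85, "tiger cat": 0.10}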
image = gr.inputs.Image(type="pil")
label = gr.outputs.Label(num_top_classes=5)
examples = [["cats.jpg"], ["batter.jpg"], ["drinkers.jpg"]]
title = "Generate a Story from an Image using BLOOM"
description = (
    "Demo for classifying images with Perceiver IO. Upload an image and click "
    "'submit': the image is classified, and a story is also generated from its "
    "caption using BigScience BLOOM."
)
article = "<p style='text-align: center'></p>"
img_info1 = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
)
img_info2 = gr.Interface(
    fn=self_caption,
    inputs=image,
    outputs=[gr.outputs.Textbox(label="Story")],
)
Parallel(
    img_info1,
    img_info2,
    inputs=image,
    title=title,
    description=description,
    examples=examples,
    enable_queue=True,
).launch(debug=True)
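
# Compatibility note (an assumption about versions, not stated in the file):
# gradio.mix.Parallel and the gr.inputs/gr.outputs namespaces exist only in
# older Gradio releases (removed in Gradio 4.x), so running this file as
# written requires pinning a 3.x release of gradio.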