|
import io |
|
import matplotlib.pyplot as plt |
|
import requests |
|
import inflect |
|
from PIL import Image |
|
|
|
def load_image_from_url(url, timeout=30):
    """Download the image at ``url`` and return it as a PIL image.

    Args:
        url: Address passed to ``requests.get``.
        timeout: Seconds ``requests`` waits on the server before giving up.
            Without a timeout a stalled server would block the caller
            indefinitely.

    Returns:
        The image object produced by ``Image.open`` on the downloaded bytes.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            4xx/5xx response status.
    """
    with requests.get(url, timeout=timeout) as response:
        # Surface HTTP errors here instead of handing an error page to PIL.
        response.raise_for_status()
        # ``content`` has any gzip/deflate transfer encoding removed, whereas
        # the ``raw`` stream would pass the still-compressed bytes to PIL.
        payload = response.content

    # Decoding from an in-memory buffer means the lazily-loading image never
    # depends on the (now released) network connection.
    return Image.open(io.BytesIO(payload))
|
|
|
def render_results_in_image(in_pil_img, in_results):
    """Draw detection boxes and labels over an image and return the rendering.

    Args:
        in_pil_img: Image passed to ``imshow`` as the background.
        in_results: Iterable of prediction dicts. Each one is read for
            ``'box'`` (with ``'xmin'``, ``'ymin'``, ``'xmax'``, ``'ymax'``),
            ``'label'`` and ``'score'``.

    Returns:
        A PIL image opened from the PNG rendering of the annotated figure.

    Raises:
        KeyError: If a prediction is missing one of the keys listed above.
    """
    fig = plt.figure(figsize=(16, 10))
    try:
        ax = fig.gca()
        ax.imshow(in_pil_img)

        for prediction in in_results:
            box = prediction['box']
            x, y = box['xmin'], box['ymin']
            w = box['xmax'] - box['xmin']
            h = box['ymax'] - box['ymin']

            ax.add_patch(plt.Rectangle((x, y),
                                       w,
                                       h,
                                       fill=False,
                                       color="green",
                                       linewidth=2))
            ax.text(
                x,
                y,
                f"{prediction['label']}: {round(prediction['score']*100, 1)}%",
                color='red'
            )

        ax.axis("off")

        # Render into memory so the result can be reopened as a PIL image
        # without touching the filesystem.
        img_buf = io.BytesIO()
        fig.savefig(img_buf, format='png',
                    bbox_inches='tight',
                    pad_inches=0)
        img_buf.seek(0)
        return Image.open(img_buf)
    finally:
        # Close this specific figure even when drawing fails; otherwise it
        # stays registered with pyplot and accumulates across calls.
        plt.close(fig)
|
|
|
def summarize_predictions_natural_language(predictions):
    """Describe how many objects of each label were predicted, in English.

    Labels are reported in order of first appearance, e.g.
    ``"In this image, there are two cats, one dog and three buses."``

    Args:
        predictions: Iterable of dicts, each providing a ``'label'`` key.

    Returns:
        A single sentence listing the count of every label, or a sentence
        stating that nothing was detected when ``predictions`` is empty.
    """
    # Plain dict keeps first-seen order, which determines the sentence order.
    counts = {}
    for prediction in predictions:
        label = prediction['label']
        counts[label] = counts.get(label, 0) + 1

    if not counts:
        # Avoid emitting the dangling fragment "In this image, there are."
        return "In this image, there are no detected objects."

    p = inflect.engine()
    # plural_noun leaves the word untouched for a count of one and otherwise
    # applies real pluralisation rules (bus -> buses) instead of a bare "s".
    phrases = [
        f"{p.number_to_words(count)} {p.plural_noun(label, count)}"
        for label, count in counts.items()
    ]

    if len(phrases) == 1:
        listing = phrases[0]
    else:
        # Commas between the leading items, "and" before the final one.
        listing = f"{', '.join(phrases[:-1])} and {phrases[-1]}"

    return f"In this image, there are {listing}."
|
|
|
|
|
|
|
import warnings |
|
import logging |
|
from transformers import logging as hf_logging |
|
|
|
def ignore_warnings():
    """Silence a fixed set of noisy warnings and restrict logging to errors.

    Registers an "ignore" warnings filter for each known message prefix,
    configures the root logger at ERROR level, and sets the transformers
    logging verbosity to error.
    """
    suppressed_messages = (
        "Some weights of the model checkpoint",
        "Could not find image processor class",
        "The `max_size` parameter is deprecated",
    )
    # Filters are registered in the listed order, one per message pattern.
    for pattern in suppressed_messages:
        warnings.filterwarnings("ignore", message=pattern)

    logging.basicConfig(level=logging.ERROR)
    hf_logging.set_verbosity_error()
|
|
|
|