# Tsgs_Extractor / demo.py
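# Dependencies (inferred from the imports below): streamlit, torch, torchvision,
# transformers, pillow, matplotlib, requests, groq.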
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch
import math
import streamlit as st
import matplotlib.pyplot as plt
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
from groq import Groq
import re
import json
import os

# Initialize the Groq client. The API key is read from the environment rather than
# hard-coded, so no secret is committed with the source.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
client = Groq(api_key=GROQ_API_KEY)
# Initialize models and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
torch.set_grad_enabled(False)
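# The detection helpers below follow the DETR demo conventions (COCO classes,
# cxcywh boxes), but this file never defines the `model` that get_objects() calls.
# A minimal sketch of the missing piece, assuming the pretrained DETR ResNet-50
# checkpoint from torch.hub (the exact variant used originally is an assumption):
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()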
# COCO classes and colors
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
# Image transformation
transform = T.Compose([T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# Helper functions
def box_cxcywh_to_xyxy(x):
    # Convert boxes from (center_x, center_y, width, height) to (x_min, y_min, x_max, y_max).
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
    # Scale normalized DETR boxes up to absolute pixel coordinates for the given image size.
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b
def plot_results(pil_img, prob, boxes):
    # Draw the detected boxes with class labels on the image and render the figure in Streamlit.
    fig = plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    st.pyplot(fig)
def get_caption(img_path):
    # Generate an image caption with the BLIP captioning model.
    raw_image = Image.open(img_path).convert('RGB')
    inputs = processor(raw_image, return_tensors="pt")
    out = caption_model.generate(**inputs)
    return str(processor.decode(out[0], skip_special_tokens=True))
def get_objects(img_path):
    # Run DETR object detection, plot boxes above a 0.9 confidence threshold,
    # and return the predicted class names.
    im = Image.open(img_path).convert('RGB')
    img = transform(im).unsqueeze(0)
    outputs = model(img)
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    plot_results(im, probas[keep], bboxes_scaled)
    return [CLASSES[p.argmax()] for p in probas[keep]]
def get_tags(text, objects):
    # Ask the Groq LLM to extract search tags from the caption and detected objects,
    # returning the tag list and an estimated request cost.
    system_prompt = """
    <SystemPrompt>
        Extract tags from the provided text.
        The tags will be used for search.
        <OutputFormat>
            Format the output in the following JSON structure:
            {
                "tags": [* list of tags here *]
            }
        </OutputFormat>
    </SystemPrompt>
    """
    try:
        user_prompt = f"Extract the tags from this text:\n{text}\n{objects}"
        chat_completion = client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt},
                      {"role": "user", "content": user_prompt}],
            model="llama3-8b-8192", response_format={"type": "json_object"},
            stream=False
        )
        json_data = json.loads(chat_completion.choices[0].message.content)
        return json_data['tags'], chat_completion.usage.total_tokens * 0.00000005
    except Exception as e:
        st.error(f"Exception | get_tags | {str(e)}")
        return [], 0.0
# Main image-to-tags pipeline
def image_to_tags(image):
    # Accept either a PIL image (as Streamlit provides) or a NumPy array.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image.save("saved_image.png")
    generated_caption = get_caption('saved_image.png')
    objects = get_objects('saved_image.png')
    tags, cost = get_tags(generated_caption, ", ".join(objects))
    return ", ".join(tags), generated_caption, ", ".join(objects), cost
# Streamlit app
st.title("Image Tagging App")
st.write("Upload an image and get captions, object detection results, and associated tags.")
# Image upload
uploaded_image = st.file_uploader("Choose an Image", type=["jpg", "png", "jpeg"])
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    # Generate tags, caption, objects, and cost
    tags, caption, objects, cost = image_to_tags(image)
    # Display results
    st.subheader("Predicted Tags:")
    st.write(tags)
    st.subheader("Caption:")
    st.write(caption)
    st.subheader("Objects Detected:")
    st.write(objects)
    st.subheader("Cost:")
    st.write(f"${cost:.6f}")
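
# To try the app locally (a usage sketch; the filename is assumed to match this demo.py):
#   streamlit run demo.py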