# emoji-tagging / app.py
# Author: vincentclaes
# Commit 90cb8eb: "simplify tagging"
import os
import torch
import boto3
import gradio as gr
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
# Hugging Face Hub id of the fine-tuned CLIP checkpoint served by this app.
checkpoint = "vincentclaes/emoji-predictor"

# Maximum number of tags reported, and the minimum probability a tag needs.
K = 10
THRESHOLD = 0.05

# Names used by log_inference when writing to S3.
APP_NAME = "emoji-tagging"
BUCKET = "drift-pilot-ml-model"

# Default tag vocabulary: column 0 of the tab-separated adjectives.txt
# (presumably one adjective per line — confirm against the data file).
adjectives = pd.read_csv("./adjectives.txt", sep="\t", header=None)[0].tolist()

# Loaded once at import so every request reuses the same weights.
model = CLIPModel.from_pretrained(checkpoint)
processor = CLIPProcessor.from_pretrained(checkpoint)
def log_inference(data=None, key=None, client=None):
    """Upload one inference record to the S3 logging bucket.

    Logging is opt-in: nothing happens unless the ``CLIENT`` environment
    variable is set to a non-empty value.

    Args:
        data: Payload for ``put_object``'s ``Body`` (bytes or a seekable
            file-like object). When ``None`` there is nothing to log and the
            call is a no-op.
        key: Object key. Defaults to a fresh ``"<APP_NAME>/<uuid4 hex>"`` so
            successive calls do not overwrite each other.
        client: Optional pre-built S3 client; a new ``boto3`` S3 client is
            created when omitted.

    Returns:
        None.
    """
    import uuid

    # .get(): an unset variable means "logging disabled", not a KeyError.
    if not os.environ.get("CLIENT") or data is None:
        return None
    if key is None:
        # Plain f-string (no "$"): the key is "<app>/<unique id>".
        key = f"{APP_NAME}/{uuid.uuid4().hex}"
    if client is None:
        client = boto3.client("s3")
    client.put_object(Body=data, Bucket=BUCKET, Key=key)
    return None
def get_tag(
    emoji,
    tags="",
    expected="",
    model=model,
    processor=processor,
    K=K,
    threshold=THRESHOLD,
):
    """Suggest tags for an emoji image, ranked by the CLIP model's scores.

    Args:
        emoji: Image handed to ``processor`` (the Gradio UI passes a PIL image).
        tags: Optional comma-separated candidate tags. Each entry is stripped
            and empty entries are dropped; if nothing remains, the adjective
            vocabulary loaded from ``adjectives.txt`` is used instead.
        expected: Unused; kept so existing callers keep working.
        model: CLIP model used for scoring.
        processor: CLIP processor matching ``model``.
        K: Maximum number of tags to report. Clamped to the number of
            candidate tags, because ``topk`` rejects a ``k`` larger than its
            input.
        threshold: Minimum probability (inclusive) a tag needs to be reported.

    Returns:
        ``"Tag (confidence): "`` followed by comma-separated ``"tag (p)"``
        entries, highest probability first, with ``p`` rounded to 2 decimals.
    """
    candidates = [t.strip() for t in tags.split(",")] if tags else []
    candidates = [t for t in candidates if t] or adjectives

    inputs = processor(
        text=candidates,
        images=emoji,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # Pure inference: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over dim 0 (one row per candidate tag) gives label probabilities;
    # column 0 corresponds to the (assumed single) input image.
    probs = outputs.logits_per_text.softmax(dim=0)[:, 0]
    values, indices = probs.topk(min(K, len(candidates)))
    return "Tag (confidence): " + ", ".join(
        f"{candidates[i]} ({round(v, 2)})"
        for v, i in zip(values.tolist(), indices.tolist())
        if v >= threshold
    )
title = "Tagging an Emoji"
# Limits are interpolated from K / THRESHOLD so the UI text cannot drift from
# the values get_tag actually uses; "at least" matches its `>=` filter.
description = f"""You provide an Emoji and our few-shot fine-tuned CLIP model will suggest some tags that are appropriate.\n
- We use the [228 most common adjectives in English](https://grammar.yourdictionary.com/parts-of-speech/adjectives/list-of-adjective-words.html).\n
- We show max {K} tags and only when the confidence is at least {THRESHOLD:.0%} ({THRESHOLD})
"""
# Example images shipped with the app: emojis/0.png ... emojis/31.png.
examples = [[f"emojis/{i}.png"] for i in range(32)]
# NOTE(review): `text` is not wired into the Interface below; it is kept only
# so the module-level name stays available. Built with `gr.Textbox` because
# the `gr.inputs` namespace is deprecated and was removed in Gradio 4.
text = gr.Textbox(
    placeholder="Enter a text and we will try to predict an emoji..."
)
# Gradio UI: one PIL image in, one textbox out. get_tag's optional `tags`
# argument is not exposed, so the default adjective vocabulary is always used.
app = gr.Interface(
    get_tag,
    inputs=[gr.Image(label="emoji", type="pil")],
    outputs=gr.Textbox(),
    title=title,
    description=description,
    examples=examples,
    examples_per_page=32,
)

if __name__ == "__main__":
    # Serve only when executed as a script, not when imported.
    app.launch()