emoji-tagging / app.py
vincentclaes's picture
simplify tagging
90cb8eb
raw
history blame
No virus
2.12 kB
import os
import torch
import boto3
import gradio as gr
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
checkpoint = "vincentclaes/emoji-predictor"
# Resolve data files relative to this script so the app also works when it is
# launched from a different working directory.
_HERE = os.path.dirname(os.path.abspath(__file__))
# keep_default_na=False stops pandas from turning word entries such as "null"
# or "NA" into NaN; every line of the file is kept as a literal string.
adjectives = pd.read_table(
    os.path.join(_HERE, "adjectives.txt"), header=None, keep_default_na=False
)[0].to_list()
K = 10  # maximum number of tags returned by get_tag
THRESHOLD = 0.05  # minimum probability for a tag to be shown
APP_NAME = "emoji-tagging"  # S3 key prefix used by log_inference
BUCKET = "drift-pilot-ml-model"  # S3 bucket used by log_inference
processor = CLIPProcessor.from_pretrained(checkpoint)
model = CLIPModel.from_pretrained(checkpoint)
def log_inference(data=b"", key=None):
    """Upload ``data`` to the inference-log S3 bucket when logging is enabled.

    Logging is enabled only when the ``CLIENT`` environment variable is set to
    a non-empty value; otherwise the call does nothing. An unset variable used
    to raise KeyError.

    Args:
        data: Payload stored as the object ``Body``.
        key: Object key. Defaults to a unique key under the ``APP_NAME/``
            prefix, so successive calls do not overwrite each other.

    Returns:
        None.
    """
    if not os.environ.get("CLIENT"):
        return
    if key is None:
        import uuid

        key = f"{APP_NAME}/{uuid.uuid4().hex}"
    boto3.client("s3").put_object(
        Body=data,
        Bucket=BUCKET,
        Key=key,
    )
def get_tag(emoji, tags="", expected="", model=model, processor=processor, K=K):
    """Suggest tags for an emoji image using the fine-tuned CLIP model.

    Args:
        emoji: The emoji image; the Gradio interface passes a PIL image.
        tags: Optional comma-separated candidate tags. When empty (or only
            separators/whitespace), the module-level ``adjectives`` list is used.
        expected: Unused; kept for interface compatibility.
        model: CLIP model that scores image/text pairs.
        processor: CLIP processor that tokenizes the tags and preprocesses
            the image.
        K: Maximum number of tags to return.

    Returns:
        ``"Tag (confidence): tag1 (p1), tag2 (p2), ..."``. Lists at most ``K``
        tags whose probability is at least ``THRESHOLD``, in descending order
        of probability.
    """
    if tags:
        # Strip each tag so "happy, sad" yields "sad" rather than " sad", and
        # drop empty entries left by stray commas.
        tags = [tag.strip() for tag in tags.split(",") if tag.strip()]
    if not tags:
        tags = adjectives
    inputs = processor(
        text=tags, images=emoji, return_tensors="pt", padding=True, truncation=True
    )
    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_text has one row per tag and one column per image. A softmax
    # over dim=0 turns the scores of all tags for the single image into
    # probabilities; column 0 selects that image.
    probs = outputs.logits_per_text.softmax(dim=0)[:, 0]
    # topk raises when k exceeds the number of candidates (e.g. < K user tags).
    values, indices = probs.topk(min(K, len(tags)))
    return "Tag (confidence): " + ", ".join(
        f"{tags[i]} ({round(v.item(), 2)})"
        for v, i in zip(values, indices)
        if v >= THRESHOLD
    )
title = "Tagging an Emoji"
# Build the description from the live settings so it cannot drift from K,
# THRESHOLD or the loaded adjective list. get_tag keeps tags with
# probability >= THRESHOLD, hence "at least".
description = f"""You provide an Emoji and our few-shot fine tuned CLIP model will suggest some tags that are appropriate.\n
- We use the [{len(adjectives)} most common adjectives in English](https://grammar.yourdictionary.com/parts-of-speech/adjectives/list-of-adjective-words.html).\n
- We show max {K} tags and only when the confidence is at least {THRESHOLD:.0%} ({THRESHOLD})
"""
# One example per image: emojis/0.png ... emojis/31.png.
examples = [[f"emojis/{i}.png"] for i in range(32)]
# NOTE(review): `text` is not wired into the interface below and looks like a
# leftover from an emoji-prediction app. gr.Textbox replaces gr.inputs.Textbox,
# which was deprecated in Gradio 3 and removed in Gradio 4.
text = gr.Textbox(
    placeholder="Enter a text and we will try to predict an emoji..."
)
app = gr.Interface(
    fn=get_tag,
    # Only the image is collected; get_tag's `tags`/`expected` keep defaults.
    inputs=[
        gr.Image(type="pil", label="emoji"),
    ],
    outputs=gr.Textbox(),
    examples=examples,
    # Show every example on a single page.
    examples_per_page=len(examples),
    title=title,
    description=description,
)

if __name__ == "__main__":
    app.launch()