# Hugging Face Spaces page header (captured with the source; kept as comments so the file parses):
# Space status: Runtime error
# File size: 2,124 bytes · commit b8821f7
import os
import torch
import boto3
import gradio as gr
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
# Hugging Face Hub checkpoint providing both the CLIP weights and its processor.
checkpoint = "vincentclaes/emoji-predictor"

# Ranking / filtering knobs used by get_tag.
K = 10  # maximum number of tags reported
THRESHOLD = 0.05  # minimum probability a tag needs to be reported

# S3 destination used by log_inference.
APP_NAME = "emoji-tagging"
BUCKET = "drift-pilot-ml-model"

# Default candidate tags: first column of ./adjectives.txt (no header row).
adjectives = pd.read_table("./adjectives.txt", header=None)[0].tolist()

# Load the processor and model once at import time so every request reuses them.
processor = CLIPProcessor.from_pretrained(checkpoint)
model = CLIPModel.from_pretrained(checkpoint)
def log_inference(data=b"", key=None):
    """Upload an inference record to the S3 logging bucket.

    Logging is opt-in: nothing is uploaded unless the ``CLIENT`` environment
    variable is set to a non-empty value.

    Args:
        data: Payload to store (anything ``put_object`` accepts as ``Body``,
            e.g. bytes or str). Defaults to an empty body.
        key: Object key inside ``BUCKET``. Defaults to a unique key under the
            ``APP_NAME/`` prefix so successive records do not overwrite
            each other.

    Returns:
        The key the record was written to, or ``None`` when logging is disabled.
    """
    # .get(): an unset CLIENT variable must disable logging, not raise KeyError.
    if not os.environ.get("CLIENT"):
        return None
    if key is None:
        import uuid  # only needed when logging is actually enabled

        # Python f-strings interpolate with {...}; a "$" prefix would end up
        # literally in the key. A random suffix gives each record its own object.
        key = f"{APP_NAME}/{uuid.uuid4().hex}"
    boto3.client("s3").put_object(Body=data, Bucket=BUCKET, Key=key)
    return key
def get_tag(
    emoji,
    tags="",
    expected="",
    model=model,
    processor=processor,
    K=K,
    threshold=THRESHOLD,
):
    """Suggest tags for an emoji image using the CLIP model.

    Args:
        emoji: Image handed to the CLIP processor (the Gradio input is
            configured to deliver a PIL image).
        tags: Optional comma-separated candidate tags. Whitespace around each
            entry is stripped and blank entries are dropped; when no usable
            tag remains, the module-level ``adjectives`` list is used.
        expected: Unused; kept for interface compatibility.
        model: CLIP model used for scoring.
        processor: CLIP processor matching ``model``.
        K: Maximum number of tags to report (capped at the number of candidates).
        threshold: Minimum probability a tag needs to be reported.

    Returns:
        ``"Tag (confidence): "`` followed by ``"tag (p)"`` entries joined with
        ``", "``, highest probability first.
    """
    candidates = []
    if tags:
        # "happy, sad," -> ["happy", "sad"]: strip each entry and drop blanks.
        candidates = [tag.strip() for tag in tags.split(",") if tag.strip()]
    if not candidates:
        candidates = adjectives

    inputs = processor(
        text=candidates,
        images=emoji,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_text is (num_texts, num_images); softmax over dim 0 gives, per
    # image, a probability distribution over the candidate tags. Keep image 0.
    probs = outputs.logits_per_text.softmax(dim=0)[:, 0]

    # topk raises if k exceeds the number of candidates (e.g. < K user tags).
    values, indices = probs.topk(min(K, len(candidates)))
    return "Tag (confidence): " + ", ".join(
        f"{candidates[int(i)]} ({round(v.item(), 2)})"
        for v, i in zip(values, indices)
        if v >= threshold
    )
title = "Tagging an Emoji"
# Built from K and THRESHOLD so the text stays in sync with get_tag's filtering
# (tags are kept when their probability is >= THRESHOLD).
description = f"""You provide an Emoji and our few-shot fine-tuned CLIP model will suggest some tags that are appropriate.\n
- We use the [228 most common adjectives in English](https://grammar.yourdictionary.com/parts-of-speech/adjectives/list-of-adjective-words.html).\n
- We show max {K} tags and only when the confidence is at least {THRESHOLD:.0%} ({THRESHOLD})
"""
# One single-input example row per bundled image: emojis/0.png .. emojis/31.png.
examples = [[f"emojis/{i}.png"] for i in range(32)]
# Free-text component for comma-separated candidate tags (get_tag's ``tags``
# argument). NOTE(review): it is not currently passed to the Interface below.
# gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4; gr.Textbox
# works in both.
text = gr.Textbox(
    placeholder="Enter comma-separated tags (optional)..."
)
# Only the image is collected, so get_tag runs with its default tag list.
emoji_input = gr.Image(type="pil", label="emoji")

app = gr.Interface(
    fn=get_tag,
    inputs=[emoji_input],
    outputs=gr.Textbox(),
    examples=examples,
    examples_per_page=len(examples),  # show every example on a single page
    title=title,
    description=description,
)

# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    app.launch()
# (end of file)