Spaces:
import os
import torch
import boto3
import gradio as gr
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
checkpoint = "vincentclaes/emoji-predictor"
adjectives = pd.read_table("./adjectives.txt", header=None)[0].to_list()
K = 10  # show at most K tags
THRESHOLD = 0.05  # only show tags with a confidence of at least 5%
APP_NAME = "emoji-tagging"
BUCKET = "drift-pilot-ml-model"

processor = CLIPProcessor.from_pretrained(checkpoint)
model = CLIPModel.from_pretrained(checkpoint)
def log_inference(data):
    # Only log to S3 when a client is configured through the CLIENT env var.
    if os.environ.get("CLIENT"):
        boto3.client("s3").put_object(
            Body=data,
            Bucket=BUCKET,
            Key=f"{APP_NAME}/",
        )
def get_tag(emoji, tags="", expected="", model=model, processor=processor, K=K):
    if tags:
        tags = tags.strip().split(",")
    else:
        tags = adjectives
    inputs = processor(
        text=tags, images=emoji, return_tensors="pt", padding=True, truncation=True
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # we take the softmax over the tags to get the label probabilities
    probs = outputs.logits_per_text.softmax(dim=0)
    probs_formatted = torch.tensor([prob[0] for prob in probs])
    # never ask for more tags than there are candidates
    values, indices = probs_formatted.topk(min(K, len(tags)))
    return "Tag (confidence): " + ", ".join(
        [
            f"{tags[i]} ({round(v.item(), 2)})"
            for v, i in zip(values, indices)
            if v >= THRESHOLD
        ]
    )
title = "Tagging an Emoji"
description = """You provide an emoji and our few-shot fine-tuned CLIP model will suggest some appropriate tags.\n
- We use the [228 most common adjectives in English](https://grammar.yourdictionary.com/parts-of-speech/adjectives/list-of-adjective-words.html).\n
- We show at most 10 tags, and only when the confidence is higher than 5% (0.05).
"""
examples = [[f"emojis/{i}.png"] for i in range(32)]

# Optional free-text widget; not wired into the interface below, but it could be
# added as a second input to supply a custom, comma-separated tag list to get_tag.
text = gr.Textbox(
    placeholder="Enter comma-separated tags to score against the emoji..."
)

app = gr.Interface(
    fn=get_tag,
    inputs=[
        gr.components.Image(type="pil", label="emoji"),
    ],
    outputs=gr.Textbox(),
    examples=examples,
    examples_per_page=32,
    title=title,
    description=description,
)
if __name__ == "__main__":
    app.launch()
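
As a quick sanity check, get_tag can also be called directly, outside the Gradio UI. The snippet below is only a sketch: it assumes the code above is saved as app.py, that the adjectives.txt file and the emojis/ folder referenced in examples are available locally, and the comma-separated tag list is purely illustrative.

from PIL import Image

from app import get_tag

# Load one of the bundled example emojis and score it against a small,
# illustrative comma-separated tag list (the format the Gradio app expects).
emoji = Image.open("emojis/0.png")
print(get_tag(emoji, tags="happy,sad,angry,funny,scary,cute,bright,dark"))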