File size: 3,094 Bytes
d1e307b
 
 
 
 
f6f9509
e5b30dc
bc7f7c5
 
 
 
 
 
 
 
 
 
794ff23
bc7f7c5
 
 
 
 
 
 
aca3de7
bc7f7c5
63d5e41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a21c6f2
 
883ed1c
a21c6f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch

# How many labels get_top_emojis returns by default.
TOP_N: int = 5
# Checkpoint loaded at startup and offered first in the model picker.
DEFAULT_MODEL: str = "amazon-sagemaker-community/xlm-roberta-en-ru-emoji-v2"

def preprocess(text):
    """Mask user mentions and links in *text* before tokenisation.

    The text is split on single spaces, so runs of spaces survive as empty
    tokens. Each token is then rewritten as follows:

    - A token that starts with ``@`` and has at least one more character
      becomes ``@user``.
    - A token that starts with ``http`` becomes ``http``.

    The tokens are re-joined with single spaces.
    """
    def _mask(token):
        if len(token) > 1 and token.startswith('@'):
            return '@user'
        if token.startswith('http'):
            return 'http'
        return token

    return " ".join(_mask(token) for token in text.split(" "))


def get_top_emojis(text, tokenizer, model, top_n=TOP_N):
    """Return the model's ``top_n`` most likely labels for *text*.

    Args:
        text: Raw user text. It is passed through ``preprocess`` first.
        tokenizer: Callable tokenizer returning PyTorch tensors for a single string.
        model: Sequence-classification model whose output has ``.logits`` and
            whose ``config.id2label`` maps class indices to labels.
        top_n: Number of labels to return.

    Returns:
        The labels joined by tab characters, highest softmax score first.
    """
    preprocessed = preprocess(text)
    # Truncate inputs longer than the model's maximum sequence length.
    # Without this, long pasted text fails inside the forward pass.
    inputs = tokenizer(preprocessed, return_tensors="pt", truncation=True)
    # This is inference only, so no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single string was tokenized, so row 0 holds its class scores.
    scores = torch.nn.functional.softmax(logits, dim=-1).numpy()[0]
    ranking = np.argsort(scores)[::-1][:top_n]
    emojis = [model.config.id2label[int(i)] for i in ranking]
    return '\t'.join(map(str, emojis))


# Load the default checkpoint when the script runs. main() compares the
# user's model choice against cur_model_name.
cur_model_name = DEFAULT_MODEL
print("cur_model", cur_model_name)

tokenizer, model = (
    AutoTokenizer.from_pretrained(cur_model_name),
    AutoModelForSequenceClassification.from_pretrained(cur_model_name),
)

# Page-level settings for the app.
st.set_page_config(
    page_title="Emoji-motion!",    # title shown in the browser tab
    page_icon=None,                # no custom page icon
    layout="centered",             # narrow, centred content column
    initial_sidebar_state="auto",  # let Streamlit choose the initial sidebar state
)

st.title('Emoji-motion!')

# Sample sentences offered in the "Choose an example" drop-down.
example_prompts = [
    "Today is going to be awesome!",
    "Pity those who don't feel anything at all.",
    "I envy people that know love.",
    "Nature is so beautiful",
]

def main():
    """Render the Emoji-motion page and show the top emoji predictions.

    The user picks an example prompt (or types their own text) and one of two
    checkpoints. If the selected checkpoint differs from the one currently
    loaded, the tokenizer and model are reloaded. On "Submit", the
    ``get_top_emojis`` result is shown as a header, or an error is shown for
    blank text.
    """
    # These module-level names are rebound below when the user switches
    # checkpoints. Without this declaration, the assignments would make them
    # locals of main(). The comparison against cur_model_name would then
    # raise UnboundLocalError on every run.
    global cur_model_name, tokenizer, model

    example = st.selectbox("Choose an example", example_prompts)

    # Take the message which needs to be processed
    message = st.text_area("...or paste some text to see the model's predictions", example)
    st.text('')
    models_to_choose = [
        "amazon-sagemaker-community/xlm-roberta-en-ru-emoji-v2",
        "AlekseyDorkin/xlm-roberta-en-ru-emoji",
    ]

    model_name = st.selectbox("Choose a model", models_to_choose)
    if model_name != cur_model_name:
        print("reloading model")
        cur_model_name = model_name
        tokenizer = AutoTokenizer.from_pretrained(cur_model_name)
        model = AutoModelForSequenceClassification.from_pretrained(cur_model_name)

    # Define function to run when submit is clicked
    def submit(message):
        # Treat whitespace-only input as empty too. Otherwise it reaches the
        # model with no real content to classify.
        if message.strip():
            st.header(get_top_emojis(message, tokenizer=tokenizer, model=model))
        else:
            st.error("The text can't be empty")

    # Run algo when submit button is clicked
    if st.button('Submit'):
        submit(message)

    st.text('')
    st.markdown(
    '''<span style="color:blue; font-size:10px">App created by [@AlekseyDorkin](https://huggingface.co/AlekseyDorkin)
     and [@akshay7](https://huggingface.co/akshay7)</span>''',
     unsafe_allow_html=True,
    )
    
# Build the page only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()