File size: 3,973 Bytes
d1e307b
 
 
 
 
f6f9509
d1e307b
a21c6f2
eda5df5
 
02af3b7
 
 
d1e307b
 
 
 
 
 
02af3b7
d1e307b
02af3b7
d1e307b
09302f5
 
 
 
 
 
 
 
 
 
02af3b7
d1e307b
a21c6f2
883ed1c
d1e307b
a21c6f2
d1e307b
 
a21c6f2
 
 
 
 
d1e307b
 
 
10dd1b8
d1e307b
 
 
 
 
 
 
 
5c4e25b
 
 
 
 
 
d1e307b
5c4e25b
d1e307b
 
 
 
 
 
 
 
a21c6f2
 
d1e307b
 
a21c6f2
 
d1e307b
 
a21c6f2
d1e307b
a21c6f2
d1e307b
a21c6f2
d1e307b
a21c6f2
 
 
 
 
 
d1e307b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch

# How many of the highest-scoring emoji labels to report per prediction.
TOP_N = 5

def main():
    """Render the Emoji-motion Streamlit page and show emoji predictions.

    The page lets the user pick an example (or type their own text), choose
    one of the available sequence-classification checkpoints, and press
    "Submit" to display the labels the model ranks highest for that text.
    """
    # Page configuration is issued before any other Streamlit command.
    st.set_page_config(  # Alternate names: setup_page, page, layout
        layout="centered",  # Can be "centered" or "wide". In the future also "dashboard", etc.
        initial_sidebar_state="auto",  # Can be "auto", "expanded", "collapsed"
        page_title="Emoji-motion!",  # String or None. Strings get appended with "• Streamlit".
        page_icon=None,  # String, anything supported by st.image, or None.
    )

    st.title('Emoji-motion!')

    example_prompts = [
        "it's pretty depressing when u hit pan on ur favourite highlighter",
        "After what just happened. In need to smoke.",
        "I've never been happier. I'm laying awake as I watch @user sleep. Thanks for making me happy again, babe.",
        "@user is the man",
        "Поприветствуем моего нового читателя @user",
        "сегодня у одной крутой бичи день рождения! @user поздравляю тебя с днем рождения! будь самой-самой счастливой,красота:* море любви тебе",
        "Никогда не явствовала себя ужаснее, чем сейчас:( я просто раздавленна",
        "Самое ужасное - это ожидание результатов",
        "печально что заряд одинаково фигово держится(",
    ]

    example = st.selectbox("Choose an example", example_prompts)

    # The text area is pre-filled with the chosen example but stays editable.
    message = st.text_area("...or paste some text to see the model's predictions", example)
    st.text('')
    models_to_choose = [
        "amazon-sagemaker-community/xlm-roberta-en-ru-emoji-v2",
        "AlekseyDorkin/xlm-roberta-en-ru-emoji"
    ]

    BASE_MODEL = st.selectbox("Choose a model", models_to_choose)

    def preprocess(text):
        """Replace user mentions with '@user' and links with 'http'."""
        normalized = []
        for token in text.split(" "):
            if token.startswith('@') and len(token) > 1:
                token = '@user'
            elif token.startswith('http'):
                token = 'http'
            normalized.append(token)
        return " ".join(normalized)

    @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
    def load_model(model_name):
        """Load (and cache per model name) the classifier and its tokenizer."""
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        return model, tokenizer

    def get_top_emojis(text, top_n=TOP_N):
        """Return the ``top_n`` highest-scoring labels for ``text``, comma-separated."""
        # Passing the name explicitly makes the cache key depend on the
        # currently selected model instead of on a captured closure variable.
        model, tokenizer = load_model(BASE_MODEL)
        # Truncate so that very long pasted text cannot exceed the model's
        # maximum sequence length.
        inputs = tokenizer(preprocess(text), return_tensors="pt", truncation=True)
        # Inference only: no gradient bookkeeping is needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        # Index the single batch row explicitly rather than squeeze(), which
        # would collapse to a 0-d array for a one-label model.
        scores = torch.nn.functional.softmax(logits, dim=-1)[0].numpy()
        ranking = np.argsort(scores)[::-1][:top_n]
        return ', '.join(str(model.config.id2label[int(i)]) for i in ranking)

    # Define function to run when submit is clicked
    def submit(message):
        """Show predictions for ``message`` or an error if it is blank."""
        if message.strip():
            st.header(get_top_emojis(message))
        else:
            st.error("The text can't be empty")

    # Run algo when submit button is clicked
    if st.button('Submit'):
        submit(message)

    st.text('')
    st.markdown(
    '''<span style="color:blue; font-size:10px">App created by [@AlekseyDorkin](https://huggingface.co/AlekseyDorkin)
     and [@akshay7](https://huggingface.co/akshay7)</span>''', 
     unsafe_allow_html=True,
    )
    
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()