File size: 3,989 Bytes
40e9898
 
 
 
 
36338f2
40e9898
fb12737
40e9898
 
 
36338f2
40e9898
 
36338f2
 
1a15874
36338f2
 
 
40e9898
 
 
36338f2
 
 
 
 
40e9898
36338f2
2b02259
40e9898
36338f2
e4461ed
 
 
 
 
 
 
6673aaa
 
1c9ee74
 
 
 
cab7f25
 
1c9ee74
6673aaa
 
cab7f25
1c9ee74
 
 
cab7f25
 
6673aaa
36338f2
 
2b02259
40e9898
e4461ed
 
2b02259
40e9898
36338f2
e4461ed
 
 
36338f2
2b02259
e4461ed
36338f2
 
2b02259
e4461ed
36338f2
5a923ef
 
 
 
 
36338f2
5a923ef
e4461ed
5a923ef
 
e4461ed
 
 
5a923ef
 
 
 
 
 
 
36338f2
e4461ed
5a923ef
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
""" Script for streamlit demo
    @author: AbinayaM02
"""

# Install necessary libraries
import json

import streamlit as st
from transformers import (
    AutoModelForCausalLM,
    AutoModelWithLMHead,
    AutoTokenizer,
    pipeline,
)

# Load the demo configuration (model names, data options, example prompts)
with open("config.json") as config_file:
    config = json.load(config_file)

# Set page layout.
# BUG FIX: page_icon was the literal string "U+270D"; Streamlit expects the
# actual character (or an emoji shortcode), so the intended ✍ (U+270D) icon
# never rendered.
st.set_page_config(
        page_title="Tamil Language Models",
        page_icon="\u270d",
        layout="wide",
        initial_sidebar_state="expanded"
    )

# Load the model
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Download (or fetch from cache) the model and tokenizer.

    Args:
        model_name: Hugging Face Hub id of the pretrained model.

    Returns:
        Tuple of (model, tokenizer).
    """
    # NOTE(review): st.cache is deprecated in newer Streamlit releases in
    # favour of st.cache_resource — left as-is for compatibility with the
    # Streamlit version this demo was built against.
    with st.spinner('Waiting for the model to load.....'):
        # AutoModelWithLMHead is deprecated; AutoModelForCausalLM is the
        # correct auto class for GPT-2-style causal language models.
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

# Side bar
img = st.sidebar.image("images/tamil_logo.jpg", width=300)

# Choose the model based on selection
st.sidebar.title("கதை சொல்லி!")
page = st.sidebar.selectbox(
    label="Select model",
    options=config["models"],
    help="Select the model to generate the text",
)
data = st.sidebar.selectbox(
    label="Select data",
    options=config[page],
    help="Select the data on which the model is trained",
)

# Per-corpus sidebar links: (wandb tracking run, Hugging Face model card)
_SIDEBAR_LINKS = {
    "Oscar + IndicNLP": (
        "[Model tracking on wandb](https://wandb.ai/wandb/hf-flax-gpt2-tamil/runs/watdq7ib/overview?workspace=user-abinayam)",
        "[Model card](https://huggingface.co/abinayam/gpt-2-tamil)",
    ),
    "Oscar": (
        "[Model tracking on wandb](https://wandb.ai/abinayam/hf-flax-gpt-2-tamil/runs/1ddv4131/overview?workspace=user-abinayam)",
        "[Model card](https://huggingface.co/flax-community/gpt-2-tamil)",
    ),
}
if page == "Text Generation" and data in _SIDEBAR_LINKS:
    tracking_link, card_link = _SIDEBAR_LINKS[data]
    st.sidebar.markdown(tracking_link, unsafe_allow_html=True)
    st.sidebar.markdown(card_link, unsafe_allow_html=True)

# Main page
st.title("Tamil Language Demos")
# BUG FIX: the second markdown link had a space between "]" and "(", which
# breaks markdown link syntax and rendered the raw URL as plain text.
st.markdown(
    "Built as part of the Flax/Jax Community week, this demo uses [GPT2 trained on Oscar dataset](https://huggingface.co/flax-community/gpt-2-tamil) "
    "and [GPT2 trained on Oscar & IndicNLP dataset](https://huggingface.co/abinayam/gpt-2-tamil) "
    "to show language generation!"
)

# Set default options for examples; "Custom" lets the user type a free prompt
prompts = config["examples"] + ["Custom"]

if page == 'Text Generation' and data == 'Oscar':
    st.header('Tamil text generation with GPT2')
    st.markdown('A simple demo using gpt-2-tamil model trained on Oscar dataset!')
    model, tokenizer = load_model(config[data])
# BUG FIX: this branch compared against "Oscar + Indic Corpus", but the
# sidebar option (see the wandb-link conditional above) is spelled
# "Oscar + IndicNLP" — the mismatch meant the model was never loaded for
# that selection, causing a NameError when generating.
elif page == 'Text Generation' and data == "Oscar + IndicNLP":
    st.header('Tamil text generation with GPT2')
    st.markdown('A simple demo using gpt-2-tamil model trained on Oscar + IndicNLP dataset')
    model, tokenizer = load_model(config[data])
else:
    st.title('Tamil News classification with Finetuned GPT2')
    st.markdown('In progress')

if page == "Text Generation":
    # Set default options
    prompt = st.selectbox('Examples', prompts, index=0)
    if prompt == "Custom":
        prompt_box = "",
        text = st.text_input(
        'Add your custom text in Tamil',
        "",
        max_chars=1000)
    else:
        prompt_box = prompt
        text = st.text_input(
        'Selected example in Tamil',
        prompt,
        max_chars=1000)
    max_len = st.slider('Select length of the sentence to generate', 25, 300, 100)
    gen_bt = st.button('Generate')

    # Generate text
    if gen_bt:
            try:
                with st.spinner('Generating...'):
                    generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
                    seqs = generator(prompt_box, max_length=max_len)[0]['generated_text']
                st.write(seqs)
            except Exception as e:
                st.exception(f'Exception: {e}')