# Author: Abinaya Mahendiran | last commit: 6673aaa ("Updated app") | size: 3.8 kB
""" Streamlit demo app for Tamil GPT-2 language models
@author: AbinayaM02
"""
# Import the required libraries
import json

import streamlit as st
from transformers import AutoModelForCausalLM, AutoModelWithLMHead, AutoTokenizer, pipeline
# Read the app configuration. Keys used below: "models" (task names for the
# sidebar), one entry per task listing its data options, one entry per data
# label giving the model id passed to load_model, and "examples" (prompts).
# The file is decoded explicitly as UTF-8 because the example prompts are
# Tamil text; open()'s locale-dependent default encoding can fail on them.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)
# Set page layout: browser-tab title/icon, wide layout, sidebar open.
# This stays ahead of every other st.* call, as st.set_page_config requires.
st.set_page_config(
    page_title="Tamil Language Models",
    # The actual WRITING HAND emoji (U+270D + emoji variation selector);
    # the plain text "U+270D" is not an emoji, so Streamlit could not use it.
    page_icon="\u270d\ufe0f",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Load the model
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load a causal language model and its tokenizer by name.

    Wrapped in ``st.cache`` so the result is reused across script reruns
    instead of being reloaded on every widget interaction;
    ``allow_output_mutation=True`` tells the cache not to hash-check the
    returned objects.

    Args:
        model_name: model identifier passed to ``from_pretrained`` — in this
            app a value looked up in ``config.json`` (presumably a Hugging
            Face Hub id such as the repos linked on the page; verify config).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    with st.spinner('Waiting for the model to load.....'):
        # AutoModelWithLMHead is deprecated in transformers; AutoModelForCausalLM
        # is the supported auto class for left-to-right models such as GPT-2.
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
# Sidebar: logo, app title ("கதை சொல்லி!" is Tamil for "Storyteller!") and
# the two pickers whose values drive the rest of the page.
sidebar = st.sidebar
img = sidebar.image("images/tamil_logo.jpg", width=300)
sidebar.title("கதை சொல்லி!")
# First the task/model...
page = sidebar.selectbox(
    label="Select model",
    options=config["models"],
    help="Select the model to generate the text",
)
# ...then the training data, whose options depend on the chosen model.
data = sidebar.selectbox(
    label="Select data",
    options=config[page],
    help="Select the data on which the model is trained",
)
# Sidebar links (experiment tracking + model card) for the selected model.
# The data labels here must match the data options in config.json; the
# main-page branch also checks for "Oscar + Indic Corpus" and uses it as the
# key for load_model.
if page == "Text Generation" and data == "Oscar + Indic Corpus":
    # One Markdown string; "\n\n" puts each link in its own paragraph.
    st.sidebar.markdown(
        "[Model tracking on wandb](https://wandb.ai/wandb/hf-flax-gpt2-tamil/runs/watdq7ib/overview?workspace=user-abinayam)\n\n"
        "[Model card](https://huggingface.co/abinayam/gpt-2-tamil)"
    )
elif page == "Text Generation" and data == "Oscar":
    # Pass a single string: a second positional argument would be taken as
    # unsafe_allow_html and its link text would never be shown.
    st.sidebar.markdown(
        "[Model tracking on wandb](https://wandb.ai/abinayam/hf-flax-gpt-2-tamil/runs/1ddv4131/overview?workspace=user-abinayam)\n\n"
        "[Model card](https://huggingface.co/flax-community/gpt-2-tamil)"
    )
# Main page: title plus a short introduction linking both models.
st.title("Tamil Language Demos")
st.markdown(
    "Built as part of the Flax/Jax Community week, this demo uses [GPT2 trained on Oscar dataset](https://huggingface.co/flax-community/gpt-2-tamil) "
    # No space between "]" and "(" — Markdown does not treat "[text] (url)" as a link.
    "and [GPT2 trained on Oscar & IndicNLP dataset](https://huggingface.co/abinayam/gpt-2-tamil) "
    "to show language generation!"
)
# Example prompts for the "Examples" picker; "Custom" lets the user type one.
prompts = config["examples"] + ["Custom"]

# Description shown for each text-generation dataset. Each key is a data
# label from config.json and is also the config key of the model to load.
_GENERATION_BLURBS = {
    "Oscar": "A simple demo using gpt-2-tamil model trained on Oscar dataset!",
    "Oscar + Indic Corpus": "A simple demo using gpt-2-tamil model trained on Oscar + IndicNLP dataset",
}

if page == "Text Generation" and data in _GENERATION_BLURBS:
    st.header("Tamil text generation with GPT2")
    st.markdown(_GENERATION_BLURBS[data])
    model, tokenizer = load_model(config[data])
else:
    # Everything else lands here, including the unfinished news-classification task.
    st.title("Tamil News classification with Finetuned GPT2")
    st.markdown("In progress")
if page == "Text Generation":
    # Pick a canned example or "Custom". In both cases the text box below
    # holds the prompt that is actually sent to the model, so an example can
    # be edited and custom text is honoured.
    prompt = st.selectbox('Examples', prompts, index=0)
    if prompt == "Custom":
        text = st.text_input(
            'Add your custom text in Tamil',
            "",
            max_chars=1000)
    else:
        text = st.text_input(
            'Selected example in Tamil',
            prompt,
            max_chars=1000)
    max_len = st.slider('Select length of the sentence to generate', 25, 300, 100)
    gen_bt = st.button('Generate')
    # Generate text
    if gen_bt:
        if not text.strip():
            # Nothing to condition on; don't send an empty prompt to the model.
            st.warning('Please enter some text to generate from.')
        else:
            try:
                with st.spinner('Generating...'):
                    generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
                    # Note: max_length bounds the whole sequence, prompt tokens included.
                    seqs = generator(text, max_length=max_len)[0]['generated_text']
                st.write(seqs)
            except Exception as e:
                # Top-level UI boundary: show the error and its traceback on the
                # page instead of crashing the script. st.exception expects the
                # exception object itself, not a formatted string.
                st.exception(e)