# NOTE(review): removed Hugging Face Spaces status-page residue
# ("Spaces: / Runtime error / Runtime error") that was scraped into this file.
"""Streamlit demo script for Tamil language models.

@author: AbinayaM02
"""
# Install necessary libraries
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline
import streamlit as st
import json

# Read the app configuration (model names, dataset options, example prompts).
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

# Set page layout
st.set_page_config(
    page_title="Tamil Language Models",
    # page_icon must be an actual emoji character (or a shortcode), not a
    # codepoint name: the previous "U+270D" was displayed as literal text.
    page_icon="\N{WRITING HAND}",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Load the model
@st.cache(allow_output_mutation=True)  # cache: don't re-download the model on every rerun
def load_model(model_name):
    """Load a causal language model and its tokenizer from the HF hub.

    Parameters
    ----------
    model_name : str
        Hugging Face model identifier (e.g. "abinayam/gpt-2-tamil").

    Returns
    -------
    tuple
        ``(model, tokenizer)`` ready to be handed to ``pipeline``.
    """
    with st.spinner('Waiting for the model to load.....'):
        # NOTE(review): AutoModelWithLMHead is deprecated in newer
        # transformers releases; AutoModelForCausalLM is the replacement.
        # Kept as-is to avoid touching the file-level import.
        model = AutoModelWithLMHead.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
# Side bar | |
img = st.sidebar.image("images/tamil_logo.jpg", width=300) | |
# Choose the model based on selection | |
st.sidebar.title("கதை சொல்லி!") | |
page = st.sidebar.selectbox(label="Select model", | |
options=config["models"], | |
help="Select the model to generate the text") | |
data = st.sidebar.selectbox(label="Select data", | |
options=config[page], | |
help="Select the data on which the model is trained") | |
if page == "Text Generation" and data == "Oscar + IndicNLP": | |
st.sidebar.markdown( | |
"[Model tracking on wandb](https://wandb.ai/wandb/hf-flax-gpt2-tamil/runs/watdq7ib/overview?workspace=user-abinayam)" | |
"[Model card](https://huggingface.co/abinayam/gpt-2-tamil)" | |
) | |
elif page == "Text Generation" and data == "Oscar": | |
st.sidebar.markdown( | |
"[Model tracking on wandb](https://wandb.ai/abinayam/hf-flax-gpt-2-tamil/runs/1ddv4131/overview?workspace=user-abinayam)", | |
"[Model card](https://huggingface.co/flax-community/gpt-2-tamil)" | |
) | |
# Main page
st.title("Tamil Language Demos")
st.markdown(
    "Built as part of the Flax/Jax Community week, this demo uses [GPT2 trained on Oscar dataset](https://huggingface.co/flax-community/gpt-2-tamil) "
    # BUG FIX: markdown links must have no space between ']' and '(' —
    # the previous "[...] (url)" rendered as plain text, not a link.
    "and [GPT2 trained on Oscar & IndicNLP dataset](https://huggingface.co/abinayam/gpt-2-tamil) "
    "to show language generation!"
)

# Set default options for examples ("Custom" lets the user type a prompt).
prompts = config["examples"] + ["Custom"]
# Pick and load the model matching the sidebar selection.
if page == 'Text Generation' and data == 'Oscar':
    st.header('Tamil text generation with GPT2')
    st.markdown('A simple demo using gpt-2-tamil model trained on Oscar dataset!')
    model, tokenizer = load_model(config[data])
elif page == 'Text Generation' and data == "Oscar + IndicNLP":
    # BUG FIX: this branch previously compared against "Oscar + Indic Corpus",
    # which never matched the "Oscar + IndicNLP" label used by the sidebar
    # condition, leaving the branch unreachable. NOTE(review): verify the
    # exact option label against config.json.
    st.header('Tamil text generation with GPT2')
    st.markdown('A simple demo using gpt-2-tamil model trained on Oscar + IndicNLP dataset')
    model, tokenizer = load_model(config[data])
else:
    st.title('Tamil News classification with Finetuned GPT2')
    st.markdown('In progress')
if page == "Text Generation":
    # Let the user pick a preset example or type a custom prompt.
    prompt = st.selectbox('Examples', prompts, index=0)
    if prompt == "Custom":
        # BUG FIX: `prompt_box = "",` created a one-element tuple and the
        # custom text the user typed was never used for generation. The
        # text-input contents (`text`) are now the generation prompt.
        text = st.text_input(
            'Add your custom text in Tamil',
            "",
            max_chars=1000)
    else:
        text = st.text_input(
            'Selected example in Tamil',
            prompt,
            max_chars=1000)
    max_len = st.slider('Select length of the sentence to generate', 25, 300, 100)
    gen_bt = st.button('Generate')

    # Generate text
    if gen_bt:
        try:
            with st.spinner('Generating...'):
                generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
                seqs = generator(text, max_length=max_len)[0]['generated_text']
            st.write(seqs)
        except Exception as e:
            # BUG FIX: st.exception expects the exception object (it renders
            # the traceback), not a pre-formatted string.
            st.exception(e)