""" Script for streamlit demo @author: AbinayaM02 """ # Install necessary libraries from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline import streamlit as st import json # Read the config with open("config.json") as f: config = json.loads(f.read()) # Set page layout st.set_page_config( page_title="Tamil Language Models", page_icon="U+270D", layout="wide", initial_sidebar_state="expanded" ) # Load the model @st.cache(allow_output_mutation=True) def load_model(model_name): with st.spinner('Waiting for the model to load.....'): model = AutoModelWithLMHead.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) return model, tokenizer # Side bar img = st.sidebar.image("images/tamil_logo.jpg", width=300) # Choose the model based on selection st.sidebar.title("கதை சொல்லி!") page = st.sidebar.selectbox(label="Select model", options=config["models"], help="Select the model to generate the text") data = st.sidebar.selectbox(label="Select data", options=config[page], help="Select the data on which the model is trained") if page == "Text Generation" and data == "Oscar + IndicNLP": st.sidebar.markdown( "[Model tracking on wandb](https://wandb.ai/wandb/hf-flax-gpt2-tamil/runs/watdq7ib/overview?workspace=user-abinayam)", unsafe_allow_html=True ) st.sidebar.markdown( "[Model card](https://huggingface.co/abinayam/gpt-2-tamil)", unsafe_allow_html=True ) elif page == "Text Generation" and data == "Oscar": st.sidebar.markdown( "[Model tracking on wandb](https://wandb.ai/abinayam/hf-flax-gpt-2-tamil/runs/1ddv4131/overview?workspace=user-abinayam)", unsafe_allow_html=True ) st.sidebar.markdown( "[Model card](https://huggingface.co/flax-community/gpt-2-tamil)", unsafe_allow_html=True ) # Main page st.title("Tamil Language Demos") st.markdown( "Built as part of the Flax/Jax Community week, this demo uses [GPT2 trained on Oscar dataset](https://huggingface.co/flax-community/gpt-2-tamil) " "and [GPT2 trained on Oscar & IndicNLP dataset] (https://huggingface.co/abinayam/gpt-2-tamil) " "to show language generation!" ) # Set default options for examples prompts = config["examples"] + ["Custom"] if page == 'Text Generation' and data == 'Oscar': st.header('Tamil text generation with GPT2') st.markdown('A simple demo using gpt-2-tamil model trained on Oscar dataset!') model, tokenizer = load_model(config[data]) elif page == 'Text Generation' and data == "Oscar + Indic Corpus": st.header('Tamil text generation with GPT2') st.markdown('A simple demo using gpt-2-tamil model trained on Oscar + IndicNLP dataset') model, tokenizer = load_model(config[data]) else: st.title('Tamil News classification with Finetuned GPT2') st.markdown('In progress') if page == "Text Generation": # Set default options prompt = st.selectbox('Examples', prompts, index=0) if prompt == "Custom": prompt_box = "", text = st.text_input( 'Add your custom text in Tamil', "", max_chars=1000) else: prompt_box = prompt text = st.text_input( 'Selected example in Tamil', prompt, max_chars=1000) max_len = st.slider('Select length of the sentence to generate', 25, 300, 100) gen_bt = st.button('Generate') # Generate text if gen_bt: try: with st.spinner('Generating...'): generator = pipeline('text-generation', model=model, tokenizer=tokenizer) seqs = generator(prompt_box, max_length=max_len)[0]['generated_text'] st.write(seqs) except Exception as e: st.exception(f'Exception: {e}')