File size: 456 Bytes
ca404cd
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import streamlit as st

from transformers import AutoTokenizer, AutoModelForCausalLM

def load_model(model_name):
    """Load a causal language model and its tokenizer with Streamlit feedback.

    A spinner is shown while the tokenizer and the model are fetched via
    ``from_pretrained``. Once both calls return, a success message is shown.

    Args:
        model_name: Model identifier or local path passed to
            ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loading_indicator = st.spinner('Waiting for the model to load.....')
    with loading_indicator:
        tok = AutoTokenizer.from_pretrained(model_name)
        # The padding id is set to the tokenizer's end-of-sequence id.
        # Presumably the checkpoint defines no dedicated pad token (typical
        # for GPT-2-style models); confirm for the models actually used.
        lm = AutoModelForCausalLM.from_pretrained(
            model_name,
            pad_token_id=tok.eos_token_id,
        )
    st.success('Model loaded!!')
    return lm, tok