# coding=utf-8 # Copyright 2023 The GIRT Authors. # Lint as: python3 # This space is built based on AMR-KELEG/ALDi and cis-lmu/GlotLID space. # GIRT Space from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import streamlit as st import base64 @st.cache_data def render_svg(svg): """Renders the given svg string.""" b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8") html = rf'
' c = st.container() c.write(html, unsafe_allow_html=True) @st.cache_resource def load_model(model_name): model = AutoModelForSeq2SeqLM.from_pretrained(model_name) return model @st.cache_resource def load_tokenizer(model_name): tokenizer = AutoTokenizer.from_pretrained(model_name) return tokenizer with st.spinner(text="Please wait while the model is loading...."): model = load_model('nafisehNik/girt-t5-base') tokenizer = load_tokenizer('nafisehNik/girt-t5-base') def compute(sample, num_beams, length_penalty, early_stopping, max_length, min_length): inputs = tokenizer(sample, return_tensors="pt").to('cpu') outputs = model.generate( **inputs, num_beams=num_beams, num_return_sequences=1, length_penalty=length_penalty, no_repeat_ngram_size=2, early_stopping=early_stopping, max_length=max_length, min_length=min_length).to('cpu') generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=False) generated_text = generated_texts[0] replace_dict = { '\n ': '\n', '': '', '