# coding=utf-8
# Copyright 2023 The GIRT Authors.
# Lint as: python3
# This space is built based on AMR-KELEG/ALDi and cis-lmu/GlotLID space.
# GIRT Space

import base64

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import streamlit as st

# Hugging Face Hub id of the model and tokenizer used by this app.
MODEL_NAME = 'nafisehNik/girt-t5-base'

# Path of the logo rendered at the top of the page.
LOGO_PATH = "assets/logo.svg"

# Decoding keeps special tokens (skip_special_tokens=False in compute), so
# they are stripped here. Replacements are applied in insertion order, so
# '<pad> ' must come before '<pad>'.
# NOTE(review): these are the standard T5 special tokens; confirm this is the
# intended cleanup for the GIRT model's output.
_OUTPUT_REPLACEMENTS = {
    '\n ': '\n',
    '</s>': '',
    '<pad> ': '',
    '<pad>': '',
    '<unk>': '',
}


@st.cache_data
def render_svg(svg):
    """Render an SVG string on the page as an inline image.

    Args:
        svg: SVG markup as text.
    """
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    # Embed the SVG as a base64 data URI so it is shown without serving a file.
    html = rf'<p align="center"><img src="data:image/svg+xml;base64,{b64}"/></p>'
    c = st.container()
    c.write(html, unsafe_allow_html=True)


@st.cache_resource
def load_model(model_name):
    """Load the seq2seq model *model_name*, cached across Streamlit reruns."""
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model


@st.cache_resource
def load_tokenizer(model_name):
    """Load the tokenizer for *model_name*, cached across Streamlit reruns."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer


with st.spinner(text="Please wait while the model is loading...."):
    model = load_model(MODEL_NAME)
    tokenizer = load_tokenizer(MODEL_NAME)


def _postprocess(text):
    """Apply every replacement in ``_OUTPUT_REPLACEMENTS`` to *text*, in order."""
    for old, new in _OUTPUT_REPLACEMENTS.items():
        text = text.replace(old, new)
    return text


def compute(sample, num_beams, length_penalty, early_stopping, max_length, min_length):
    """Generate a single output for *sample* with the loaded model.

    Args:
        sample: Prompt text passed to the tokenizer.
        num_beams: Beam count for ``model.generate``.
        length_penalty: Length penalty for ``model.generate``.
        early_stopping: Early-stopping flag for ``model.generate``.
        max_length: Maximum output length for ``model.generate``.
        min_length: Minimum output length for ``model.generate``.

    Returns:
        The first decoded sequence with special tokens cleaned up by
        ``_postprocess``.
    """
    inputs = tokenizer(sample, return_tensors="pt").to('cpu')
    outputs = model.generate(
        **inputs,
        num_beams=num_beams,
        num_return_sequences=1,
        length_penalty=length_penalty,
        no_repeat_ngram_size=2,
        early_stopping=early_stopping,
        max_length=max_length,
        min_length=min_length,
    ).to('cpu')
    # Special tokens are kept during decoding and removed by _postprocess.
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
    return _postprocess(generated_text)


st.markdown("[![Duplicate Space](https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14)](https://huggingface.co/spaces/nafisehNik/girt-space?duplicate=true)")

with open(LOGO_PATH, encoding="utf-8") as logo_file:
    render_svg(logo_file.read())

tab1, tab2 = st.tabs(["Design GitHub Issue Template", "Manual Prompt"])

with tab1:
    pass  # Placeholder: this tab has no content yet.

with tab2:
    sent = st.text_input("Sentence:", placeholder="Enter a prompt.")
    # The button's value is not read: clicking it only triggers a Streamlit
    # rerun, and generation below runs whenever the text box is non-empty.
    st.button("Submit")
    if sent:
        res = compute(
            sent,
            num_beams=2,
            length_penalty=1.0,
            early_stopping=True,
            max_length=300,
            min_length=20,
        )
        st.code(res, language="python")