File size: 714 Bytes
d6a1cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

import streamlit as st
from transformers import AutoTokenizer
import torch

@st.cache(allow_output_mutation=True)
def get_model():
    """Load and return the pretrained "gpt2" tokenizer.

    Wrapped in Streamlit's cache decorator so the tokenizer object is
    reused instead of being reloaded on each call.

    NOTE(review): `st.cache` is reported as deprecated in recent Streamlit
    releases in favour of `st.cache_resource` -- confirm against the
    Streamlit version this app is pinned to before switching.
    """
    return AutoTokenizer.from_pretrained("gpt2")
    
# Module-level tokenizer shared by run_generate(); obtained from the cached loader.
tokenizer = get_model()
# Text box for the words to exclude; the default value is a sample,
# whitespace-separated list of four words.
bad_words = st.text_input("Words You Do Not Want Generated", " core lemon height time ")

def run_generate(bad_words):
    """Convert a whitespace-separated string of words into token-id lists.

    Args:
        bad_words: String of words separated by whitespace.

    Returns:
        A list with one entry per word, each entry being the `input_ids`
        the module-level `tokenizer` produces for that word prefixed with a
        single space. An empty or whitespace-only string yields `[]`.
    """
    # Each word is tokenized with a leading space -- presumably so the ids
    # match how the word is encoded mid-sentence by GPT-2's tokenizer
    # (TODO confirm against the tokenizer documentation).
    bad_word_ids = [
        tokenizer(" " + word).input_ids for word in bad_words.split()
    ]
    print(bad_word_ids)
    return bad_word_ids
  
# Only tokenize when the text box holds something. The guard tests the
# widget value `bad_words`; no variable named `text` exists in this script.
if bad_words:
    translated_text = run_generate(bad_words)
    # A whitespace-only input passes the guard but tokenizes to an empty
    # list, in which case the fallback message is shown instead.
    st.write(translated_text if translated_text else "No translation found")