# File size: 971 Bytes (header left over from a web file viewer; the commit-hash and line-number gutters carried no program content)

import streamlit as st
from transformers import AutoTokenizer
import torch

@st.cache(allow_output_mutation=True)
def get_model():
    """Load and return the pretrained "gpt2" tokenizer.

    Despite the name, no model is loaded here: the function returns only the
    tokenizer object produced by ``AutoTokenizer.from_pretrained``. The
    ``st.cache`` decorator is applied so the load is not repeated on every
    Streamlit rerun.
    """
    return AutoTokenizer.from_pretrained("gpt2")
    
# Module-level tokenizer, obtained through the cached loader; run_generate
# reads this global.
tokenizer = get_model()
# Text box holding the whitespace-separated words to process. It is pre-filled
# with a sample list so the page shows output on first load.
bad_words = st.text_input("Words You Do Not Want Generated", " core lemon height time ")

def run_generate(bad_words, bias=-30):
    """Build a ``{token_id: bias, ...},`` snippet for the given words.

    Each whitespace-separated word is prefixed with a single space, encoded
    with the module-level ``tokenizer``, and every resulting token id is
    paired with ``bias``. The snippet is also printed to stdout.

    Args:
        bad_words: String of whitespace-separated words.
        bias: Value written next to every token id. Defaults to ``-30``.

    Returns:
        A string such as ``"{123: -30, 456: -30},"``. When ``bad_words``
        contains no words the result is ``"{},"``.
    """
    entries = []
    for word in bad_words.split():
        # The word is encoded with a leading space -- presumably so it maps to
        # the token ids it has in the middle of a sentence; confirm against
        # the tokenizer in use.
        token_ids = tokenizer(" " + word).input_ids
        # A word may encode to several token ids; each one gets its own entry.
        entries.extend(f"{token_id}: {bias}" for token_id in token_ids)

    # Assemble the text directly rather than rewriting the repr of a list, so
    # the braces and separators are correct for zero, one or many entries.
    bad_word_ids = "{" + ", ".join(entries) + "},"
    print(bad_word_ids)
    return bad_word_ids
  
# Only generate output when at least one word was entered. A whitespace-only
# value is truthy but contains nothing to tokenize, so it is reported instead
# of being passed on.
if bad_words and bad_words.strip():
    bias_text = run_generate(bad_words)
    st.write(bias_text)
elif bad_words:
    st.write("No words entered")