Jellyfish042 commited on
Commit
4d3d295
1 Parent(s): aaa21f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -1
app.py CHANGED
@@ -1,7 +1,24 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
  def text_processing(text):
4
- return text * 2
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  iface = gr.Interface(fn = text_processing, inputs='text', outputs=['text'], title='test', description='test space')
7
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+
4
+ tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
5
+ model = AutoModelForSeq2SeqLM.from_pretrained("./models/checkpoint-15000/")
6
+
7
 
8
  def text_processing(text):
9
+ inputs = [text]
10
+
11
+ # Tokenize and prepare the inputs for model
12
+ input_ids = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").input_ids.to(device)
13
+ attention_mask = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").attention_mask.to(device)
14
+
15
+ # Generate prediction
16
+ output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=512)
17
+
18
+ # Decode the prediction
19
+ decoded_output = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
20
+
21
+ return decoded_output[0]
22
 
23
  iface = gr.Interface(fn = text_processing, inputs='text', outputs=['text'], title='test', description='test space')
24