mrm8488 committed on
Commit
4b7a8dd
1 Parent(s): 45e01da

Change generation interface

Browse files
Files changed (1) hide show
  1. app.py +5 -20
app.py CHANGED
@@ -1,33 +1,18 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import T5Tokenizer, AutoModelForCausalLM
4
  from utils import translate_from_jp_to_en
5
 
6
  tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt-1b")
7
  model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-1b")
8
 
 
9
 
10
def generate(text, max_length=512):
    """Generate a Japanese continuation of *text* with rinna/japanese-gpt-1b.

    Parameters
    ----------
    text : str
        Japanese prompt to continue.
    max_length : int, optional
        Maximum total token length of the generated sequence (default 512).

    Returns
    -------
    tuple[str, str]
        (japanese_output, english_translation) — the translation comes from
        ``utils.translate_from_jp_to_en``.
    """
    token_ids = tokenizer.encode(
        text, add_special_tokens=False, return_tensors="pt")

    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(
            token_ids,
            max_length=max_length,
            do_sample=True,
            top_k=500,
            top_p=0.95,
            #pad_token_id=tokenizer.pad_token_id,
            #bos_token_id=tokenizer.bos_token_id,
            #eos_token_id=tokenizer.eos_token_id,
            #bad_word_ids=[[tokenizer.unk_token_id]]
            early_stopping=False,
        )

    # BUG FIX: the kwarg was misspelled "skip_specual_tokens"; tokenizer.decode
    # silently ignores unknown keyword arguments, so special tokens were leaking
    # into the output. Correct spelling actually strips them.
    output = tokenizer.decode(output_ids.tolist()[0], skip_special_tokens=True)
    return output, translate_from_jp_to_en(output)
31
 
32
 
33
  title = "JP GPT Demo"
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import T5Tokenizer, AutoModelForCausalLM, pipeline
4
  from utils import translate_from_jp_to_en
5
 
6
  tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt-1b")
7
  model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-1b")
8
 
9
+ generator = pipeline("text-generation", tokenizer=tokenizer, model=model)
10
 
11
def generate(text, max_length=512):
    """Sample one continuation of *text* from the text-generation pipeline.

    Returns a (japanese_output, english_translation) pair, where the
    translation is produced by ``translate_from_jp_to_en``.
    """
    results = generator(
        text,
        do_sample=True,
        max_length=max_length,
        num_return_sequences=1,
    )
    generated = results[0]['generated_text']
    return generated, translate_from_jp_to_en(generated)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  title = "JP GPT Demo"