Charles1973 commited on
Commit
99ac71e
·
1 Parent(s): f4f6d4a

app.py修正

Browse files
Files changed (1) hide show
  1. app.py +25 -5
app.py CHANGED
@@ -22,13 +22,33 @@ import torch
22
 
23
  # トークナイザーとモデルの準備
24
  tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
25
- model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
 
26
 
27
  # 推論の実行
 
 
 
 
28
  def Chat(prompt):
29
- input = tokenizer.encode(prompt, return_tensors="pt")
30
- output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=5)
31
- return tokenizer.batch_decode(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- app = gr.Interface(fn=Chat, inputs=gr.Textbox(lines=3, placeholder="文章を入力してください"), outputs="text" , title="りんな GPT-2 medium")
34
  app.launch()
 
22
 
23
  # トークナイザーとモデルの準備
24
  tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
25
+ # model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
26
+ model = AutoModelForCausalLM.from_pretrained("output/")
27
 
28
  # 推論の実行
29
+ #def Chat(prompt):
30
+ # input = tokenizer.encode(prompt, return_tensors="pt")
31
+ # output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=5)
32
+ # return tokenizer.batch_decode(output)
33
  def Chat(prompt):
34
+ num = 3
35
+ input_ids = tokenizer.encode(prompt, return_tensors="pt",add_special_tokens=False).to(device)
36
+ #with torch.no_grad():
37
+ output = model.generate(
38
+ input_ids,
39
+ max_length=300, # 最長の文章長
40
+ min_length=100, # 最短の文章長
41
+ do_sample=True,
42
+ top_k=500, # 上位{top_k}個の文章を保持
43
+ top_p=0.95, # 上位{top_p}%の単語から選択する。例)上位95%の単語から選んでくる
44
+ pad_token_id=tokenizer.pad_token_id,
45
+ bos_token_id=tokenizer.bos_token_id,
46
+ eos_token_id=tokenizer.eos_token_id,
47
+ #bad_word_ids=[[tokenizer.unk_token_id]],
48
+ num_return_sequences=num # 生成する文章の数
49
+ )
50
+ decoded = tokenizer.decode(output.tolist()[0])
51
+ return decoded
52
 
53
+ app = gr.Interface(fn=Chat, inputs=gr.Textbox(lines=3, placeholder="文章を入力してください"), outputs="text" , title="夏目漱石GPT")
54
  app.launch()