Spaces:
Runtime error
Runtime error
Commit
·
99ac71e
1
Parent(s):
f4f6d4a
app.py修正
Browse files
app.py
CHANGED
@@ -22,13 +22,33 @@ import torch
|
|
22 |
|
23 |
# トークナイザーとモデルの準備
|
24 |
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
|
25 |
-
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
|
|
|
26 |
|
27 |
# 推論の実行
|
|
|
|
|
|
|
|
|
28 |
def Chat(prompt):
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
app = gr.Interface(fn=Chat, inputs=gr.Textbox(lines=3, placeholder="文章を入力してください"), outputs="text" , title="
|
34 |
app.launch()
|
|
|
22 |
|
23 |
# トークナイザーとモデルの準備
|
24 |
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
|
25 |
+
# model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
|
26 |
+
model = AutoModelForCausalLM.from_pretrained("output/")
|
27 |
|
28 |
# 推論の実行
|
29 |
+
#def Chat(prompt):
|
30 |
+
# input = tokenizer.encode(prompt, return_tensors="pt")
|
31 |
+
# output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=5)
|
32 |
+
# return tokenizer.batch_decode(output)
|
33 |
def Chat(prompt):
|
34 |
+
num = 3
|
35 |
+
input_ids = tokenizer.encode(prompt, return_tensors="pt",add_special_tokens=False).to(device)
|
36 |
+
#with torch.no_grad():
|
37 |
+
output = model.generate(
|
38 |
+
input_ids,
|
39 |
+
max_length=300, # 最長の文章長
|
40 |
+
min_length=100, # 最短の文章長
|
41 |
+
do_sample=True,
|
42 |
+
top_k=500, # 上位{top_k}個の文章を保持
|
43 |
+
top_p=0.95, # 上位{top_p}%の単語から選択する。例)上位95%の単語から選んでくる
|
44 |
+
pad_token_id=tokenizer.pad_token_id,
|
45 |
+
bos_token_id=tokenizer.bos_token_id,
|
46 |
+
eos_token_id=tokenizer.eos_token_id,
|
47 |
+
#bad_word_ids=[[tokenizer.unk_token_id]],
|
48 |
+
num_return_sequences=num # 生成する文章の数
|
49 |
+
)
|
50 |
+
decoded = tokenizer.decode(output.tolist()[0])
|
51 |
+
return decoded
|
52 |
|
53 |
+
app = gr.Interface(fn=Chat, inputs=gr.Textbox(lines=3, placeholder="文章を入力してください"), outputs="text" , title="夏目漱石GPT")
|
54 |
app.launch()
|