Update app.py
Browse files
app.py
CHANGED
@@ -27,9 +27,15 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=0.0, filter_value=-float('Inf'
|
|
27 |
return logits
|
28 |
|
29 |
def generate(title, context, max_len):
|
|
|
|
|
|
|
30 |
title_ids = tokenizer.encode(title, add_special_tokens=False)
|
31 |
context_ids = tokenizer.encode(context, add_special_tokens=False)
|
|
|
|
|
32 |
input_ids = title_ids + [sep_id] + context_ids
|
|
|
33 |
cur_len = len(input_ids)
|
34 |
input_len = cur_len
|
35 |
last_token_id = input_ids[-1]
|
@@ -60,7 +66,6 @@ if __name__ == '__main__':
|
|
60 |
eod_id = tokenizer.convert_tokens_to_ids("<eod>")
|
61 |
sep_id = tokenizer.sep_token_id
|
62 |
unk_id = tokenizer.unk_token_id
|
63 |
-
|
64 |
|
65 |
|
66 |
gr.Interface(
|
@@ -71,4 +76,4 @@ if __name__ == '__main__':
|
|
71 |
"number"
|
72 |
],
|
73 |
outputs=gr.Textbox(lines=15, placeholder="AI生成的文本显示在这里。",label="生成的文本")
|
74 |
-
).launch()
|
|
|
27 |
return logits
|
28 |
|
29 |
def generate(title, context, max_len):
|
30 |
+
|
31 |
+
# input_ids.extend( tokenizer.encode(input_text + "-", add_special_tokens=False) )
|
32 |
+
|
33 |
title_ids = tokenizer.encode(title, add_special_tokens=False)
|
34 |
context_ids = tokenizer.encode(context, add_special_tokens=False)
|
35 |
+
print(title_ids,context_ids)
|
36 |
+
|
37 |
input_ids = title_ids + [sep_id] + context_ids
|
38 |
+
|
39 |
cur_len = len(input_ids)
|
40 |
input_len = cur_len
|
41 |
last_token_id = input_ids[-1]
|
|
|
66 |
eod_id = tokenizer.convert_tokens_to_ids("<eod>")
|
67 |
sep_id = tokenizer.sep_token_id
|
68 |
unk_id = tokenizer.unk_token_id
|
|
|
69 |
|
70 |
|
71 |
gr.Interface(
|
|
|
76 |
"number"
|
77 |
],
|
78 |
outputs=gr.Textbox(lines=15, placeholder="AI生成的文本显示在这里。",label="生成的文本")
|
79 |
+
).launch()
|