Update app.py
app.py
CHANGED
@@ -9,19 +9,20 @@ def loadModels():
     _tokenizer = PreTrainedTokenizerFast.from_pretrained(repository)
 
     print("Loaded :)")
-
     return _model, _tokenizer
-
+
 model, tokenizer = loadModels()
 
 lit.title("성경말투 생성기")
 lit.caption("적당한 길이의 한 문장을 넣었을 때 가장 좋은 결과가 나옵니다.")
 lit.caption("https://github.com/rycont/kobart-biblify")
 
-text_input = lit.text_area("문장 입력")
-
 MAX_LENGTH = 128
 
+with lit.form("GEN"):
+    text_input = lit.text_area("문장 입력")
+    submitted = lit.form_submit_button("생성")
+
 def biblifyWithBeams(beam, tokens, attention_mask):
     generated = model.generate(
         input_ids = torch.Tensor([ tokens ]).to(torch.int64),
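The headline change in this commit to the 성경말투 생성기 ("Bible-style text generator") app is that the text area moves inside a Streamlit form. A bare st.text_area reruns the whole script every time its value changes; a form defers the rerun until the submit button is pressed, so the model only runs on an explicit "생성" ("generate") click. A minimal standalone sketch of the pattern, with a placeholder echo standing in for the app's tokenize-and-generate step:

    import streamlit as lit

    # Widgets inside a form do not trigger reruns on their own; the script
    # reruns once when the submit button is pressed.
    with lit.form("GEN"):
        text_input = lit.text_area("문장 입력")      # "sentence input"
        submitted = lit.form_submit_button("생성")   # "generate"

    if submitted and len(text_input.strip()) > 0:
        lit.write(text_input)  # placeholder for tokenization + beam search

Whether the committed code also checks submitted is not visible in this diff; the later hunks still gate only on len(text_input.strip()) > 0.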
@@ -30,8 +31,8 @@ def biblifyWithBeams(beam, tokens, attention_mask):
         max_length = MAX_LENGTH,
         eos_token_id=tokenizer.eos_token_id,
         bad_words_ids=[[tokenizer.unk_token_id]]
-
-
+    )[0]
+
     return tokenizer.decode(
         generated,
     ).replace('<s>', '').replace('</s>', '')
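The two lines removed in the hunk above appear to have been blank, so the previous revision never actually closed the model.generate( call; this commit closes it and takes [0] of the result. The index matters because generate returns a batch of sequences (shape [batch, seq_len]) while tokenizer.decode expects a single 1-D sequence of ids. A torch-only sketch with made-up token ids:

    import torch

    # Stand-in for a generate() result: a batch containing one sequence.
    batch_output = torch.tensor([[0, 4212, 7891, 1]])
    sequence = batch_output[0]  # 1-D row, the form tokenizer.decode accepts
    print(batch_output.shape, sequence.shape)  # torch.Size([1, 4]) torch.Size([4])

The .replace('<s>', '') and .replace('</s>', '') calls then strip the BOS/EOS markers from the decoded string; decode(..., skip_special_tokens=True) would do the same more idiomatically, but the commit keeps the string replaces.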
@@ -40,9 +41,10 @@ if len(text_input.strip()) > 0:
     print(text_input)
 
     text_input = "<s>" + text_input + "</s>"
-    tokens = tokenizer.encode(text_input)
 
+    tokens = tokenizer.encode(text_input)
     tokenLength = len(tokens)
+
     attentionMasks = [ 1 ] * tokenLength + [ 0 ] * (MAX_LENGTH - tokenLength)
     tokens = tokens + [ tokenizer.pad_token_id ] * (MAX_LENGTH - tokenLength)
 
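This hunk only reshuffles blank lines around the padding logic, but that logic rewards a worked example: the attention mask gets a 1 for every real token and a 0 for every pad slot, and the token list is padded with the pad id out to MAX_LENGTH. With a hypothetical 5-token encoding and MAX_LENGTH shrunk from 128 to 8 for readability:

    MAX_LENGTH = 8                     # the app uses 128
    pad_token_id = 3                   # hypothetical pad id
    tokens = [0, 9123, 455, 7781, 1]   # hypothetical encoding of "<s>...</s>"

    tokenLength = len(tokens)
    attentionMasks = [ 1 ] * tokenLength + [ 0 ] * (MAX_LENGTH - tokenLength)
    tokens = tokens + [ pad_token_id ] * (MAX_LENGTH - tokenLength)

    print(attentionMasks)  # [1, 1, 1, 1, 1, 0, 0, 0]
    print(tokens)          # [0, 9123, 455, 7781, 1, 3, 3, 3]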
@@ -53,17 +55,16 @@ if len(text_input.strip()) > 0:
             i + 1,
             tokens,
             attentionMasks
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        )
+
+        if generated in results:
+            print("중복됨")
+            continue
+
+        results.append(generated)
+
+        with lit.expander(str(len(results)) + "번째 결과 (" + str(i +1) + ")", True):
+            lit.write(generated)
+            print(generated)
+
+    lit.caption("및 " + str(5 - len(results)) + " 개의 중복된 결과")
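The last hunk fills in the body of the beam loop, whose header sits in context lines the diff does not show. From the visible pieces (the biblifyWithBeams(i + 1, tokens, attentionMasks) call and the 5 - len(results) caption), the surrounding code plausibly reads like the sketch below, plugged into the definitions from the earlier hunks; the results list, the range(5) bound, and the dropped enclosing if len(text_input.strip()) > 0: block are assumptions:

    results = []                       # assumed; implied by len(results) below
    for i in range(5):                 # assumed bound, implied by "5 - len(results)"
        generated = biblifyWithBeams(  # beam widths 1 through 5
            i + 1,
            tokens,
            attentionMasks
        )

        if generated in results:       # new in this commit: skip repeated outputs
            print("중복됨")            # "duplicate"
            continue

        results.append(generated)

        # Expander label reads "Nth result (beam width)".
        with lit.expander(str(len(results)) + "번째 결과 (" + str(i + 1) + ")", True):
            lit.write(generated)
            print(generated)

    # "and N duplicated results"
    lit.caption("및 " + str(5 - len(results)) + " 개의 중복된 결과")

Deduplication pays off here because adjacent beam widths often converge on the same sentence; only distinct outputs get an expander, and the caption accounts for the rest.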
|