Update app.py
Browse files
app.py
CHANGED
@@ -48,10 +48,19 @@ def change_model_name(name):
|
|
48 |
|
49 |
|
50 |
def generate(model, text):
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
|
57 |
|
|
|
48 |
|
49 |
|
50 |
def generate(model, text):
|
51 |
+
|
52 |
+
if model_name != MODEL_NAME:
|
53 |
+
change_model_name(model_name)
|
54 |
+
|
55 |
+
tokenizer = MODEL_BUF["tokenizer"]
|
56 |
+
model = MODEL_BUF["model"]
|
57 |
+
config = MODEL_BUF["config"]
|
58 |
+
|
59 |
+
model.eval()
|
60 |
+
input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
|
61 |
+
outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5, top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
|
62 |
+
|
63 |
+
return tokenizer.decode(outputs[0])
|
64 |
|
65 |
|
66 |
|