dmariko commited on
Commit
5963b0e
1 Parent(s): 3c5ec27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -48,10 +48,19 @@ def change_model_name(name):
48
 
49
 
50
  def generate(model, text):
51
- model.eval()
52
- input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
53
- outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5, top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
54
- return tokenizer.decode(outputs[0])
 
 
 
 
 
 
 
 
 
55
 
56
 
57
 
 
48
 
49
 
50
  def generate(model, text):
51
+
52
+ if model_name != MODEL_NAME:
53
+ change_model_name(model_name)
54
+
55
+ tokenizer = MODEL_BUF["tokenizer"]
56
+ model = MODEL_BUF["model"]
57
+ config = MODEL_BUF["config"]
58
+
59
+ model.eval()
60
+ input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
61
+ outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5, top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
62
+
63
+ return tokenizer.decode(outputs[0])
64
 
65
 
66