shibing624 commited on
Commit
0f6f16f
1 Parent(s): e8d4933

update demo.

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -15,7 +15,7 @@ model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
15
 
16
  def ai_text(text):
17
  with torch.no_grad():
18
- outputs = model(**tokenizer(text, padding=True, return_tensors='pt'))
19
 
20
  def get_errors(corrected_text, origin_text):
21
  sub_details = []
@@ -35,7 +35,7 @@ def ai_text(text):
35
  sub_details = sorted(sub_details, key=operator.itemgetter(2))
36
  return corrected_text, sub_details
37
 
38
- _text = tokenizer.decode(torch.argmax(outputs.logits, dim=-1), skip_special_tokens=True).replace(' ', '')
39
  corrected_text = _text[:len(text)]
40
  corrected_text, details = get_errors(corrected_text, text)
41
  print(text, ' => ', corrected_text, details)
 
15
 
16
  def ai_text(text):
17
  with torch.no_grad():
18
+ outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
19
 
20
  def get_errors(corrected_text, origin_text):
21
  sub_details = []
 
35
  sub_details = sorted(sub_details, key=operator.itemgetter(2))
36
  return corrected_text, sub_details
37
 
38
+ _text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
39
  corrected_text = _text[:len(text)]
40
  corrected_text, details = get_errors(corrected_text, text)
41
  print(text, ' => ', corrected_text, details)