Halo Master committed
Commit 1f5f1b1
1 Parent(s): f0eb52a

large model

Files changed (1)
app.py +9 -8
app.py CHANGED
@@ -19,26 +19,27 @@ model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1")
 
 device = torch.device('cpu')
 model.to(device)
-def preprocess(text):
-    return text.replace("\n", "_")
 
+def preprocess(text):
+    text = text.replace("\n", "\\n").replace("\t", "\\t")
+    return text
+
 def postprocess(text):
-    return text.replace("_", "\n")
-
-def answer(text, sample=False, top_p=0.6):
+    return text.replace("\\n", "\n").replace("\\t", "\t")
+
+def answer(text, sample=True, top_p=1, temperature=0.7):
     '''sample: whether to sample; for generation tasks this can be set to True.
     top_p: between 0 and 1; the higher it is, the more diverse the generated content.'''
     text = preprocess(text)
     encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=768, return_tensors="pt").to(device)
     if not sample:
-        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=128, num_beams=4, length_penalty=0.6)
+        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, num_beams=1, length_penalty=0.6)
     else:
-        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=128, do_sample=True, top_p=top_p)
+        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
     out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
     return postprocess(out_text[0])
 
 
-
 #iface = gr.Interface(fn=answer, inputs="text", outputs="text")
 examples = [
     ["""摘要这段话: