qianmuuq commited on
Commit
419df27
1 Parent(s): 99bc354

Update main2.py

Browse files
Files changed (1) hide show
  1. main2.py +2 -3
main2.py CHANGED
@@ -25,9 +25,8 @@ def set_args():
25
  # help='模型参数')
26
  parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
27
  parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
28
- parser.add_argument('--vocab_path', default='./bert-base-zh\\vocab.txt', type=str, required=False,
29
  help='对话模型路径')
30
- parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径")
31
  parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
32
  help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
33
  # parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
@@ -65,7 +64,7 @@ def create_logger(args):
65
  class Word_BERT(nn.Module):
66
  def __init__(self, seq_label=1,cancer_label=8,transfer_label=2,ly_transfer=2):
67
  super(Word_BERT, self).__init__()
68
- self.bert = BertModel.from_pretrained('./bert-base-zh')
69
  # self.bert_config = self.bert.config
70
  self.out = nn.Sequential(
71
  # nn.Linear(768,256),
 
25
  # help='模型参数')
26
  parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
27
  parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
28
+ parser.add_argument('--vocab_path', default='/app/bert-base-zh/vocab.txt', type=str, required=False,
29
  help='对话模型路径')
 
30
  parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
31
  help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
32
  # parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
 
64
  class Word_BERT(nn.Module):
65
  def __init__(self, seq_label=1,cancer_label=8,transfer_label=2,ly_transfer=2):
66
  super(Word_BERT, self).__init__()
67
+ self.bert = BertModel.from_pretrained('/app/bert-base-zh')
68
  # self.bert_config = self.bert.config
69
  self.out = nn.Sequential(
70
  # nn.Linear(768,256),