Update main2.py
Browse files
main2.py
CHANGED
@@ -25,9 +25,8 @@ def set_args():
|
|
25 |
# help='模型参数')
|
26 |
parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
|
27 |
parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
|
28 |
-
parser.add_argument('--vocab_path', default='
|
29 |
help='对话模型路径')
|
30 |
-
parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径")
|
31 |
parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
|
32 |
help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
|
33 |
# parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
|
@@ -65,7 +64,7 @@ def create_logger(args):
|
|
65 |
class Word_BERT(nn.Module):
|
66 |
def __init__(self, seq_label=1,cancer_label=8,transfer_label=2,ly_transfer=2):
|
67 |
super(Word_BERT, self).__init__()
|
68 |
-
self.bert = BertModel.from_pretrained('
|
69 |
# self.bert_config = self.bert.config
|
70 |
self.out = nn.Sequential(
|
71 |
# nn.Linear(768,256),
|
|
|
25 |
# help='模型参数')
|
26 |
parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
|
27 |
parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
|
28 |
+
parser.add_argument('--vocab_path', default='/app/bert-base-zh/vocab.txt', type=str, required=False,
|
29 |
help='对话模型路径')
|
|
|
30 |
parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
|
31 |
help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
|
32 |
# parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
|
|
|
64 |
class Word_BERT(nn.Module):
|
65 |
def __init__(self, seq_label=1,cancer_label=8,transfer_label=2,ly_transfer=2):
|
66 |
super(Word_BERT, self).__init__()
|
67 |
+
self.bert = BertModel.from_pretrained('/app/bert-base-zh')
|
68 |
# self.bert_config = self.bert.config
|
69 |
self.out = nn.Sequential(
|
70 |
# nn.Linear(768,256),
|