Update main.py
Browse files
main.py
CHANGED
@@ -38,7 +38,7 @@ def set_args():
|
|
38 |
# help='模型参数')
|
39 |
parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
|
40 |
parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
|
41 |
-
parser.add_argument('--vocab_path', default='
|
42 |
help='对话模型路径')
|
43 |
parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径")
|
44 |
parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
|
@@ -78,7 +78,7 @@ def create_logger(args):
|
|
78 |
class Word_BERT(nn.Module):
|
79 |
def __init__(self, seq_label=1,cancer_label=8,transfer_label=2,ly_transfer=2):
|
80 |
super(Word_BERT, self).__init__()
|
81 |
-
self.bert = BertModel.from_pretrained('
|
82 |
# self.bert_config = self.bert.config
|
83 |
self.out = nn.Sequential(
|
84 |
# nn.Linear(768,256),
|
|
|
38 |
# help='模型参数')
|
39 |
parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
|
40 |
parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
|
41 |
+
parser.add_argument('--vocab_path', default='./bert-base-zh\\vocab.txt', type=str, required=False,
|
42 |
help='对话模型路径')
|
43 |
parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径")
|
44 |
parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
|
|
|
78 |
class Word_BERT(nn.Module):
|
79 |
def __init__(self, seq_label=1,cancer_label=8,transfer_label=2,ly_transfer=2):
|
80 |
super(Word_BERT, self).__init__()
|
81 |
+
self.bert = BertModel.from_pretrained('./bert-base-zh')
|
82 |
# self.bert_config = self.bert.config
|
83 |
self.out = nn.Sequential(
|
84 |
# nn.Linear(768,256),
|