| import os |
| import torch |
| import numpy as np |
| from torch import nn |
| from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast, AutoConfig |
|
|
| |
# Label names; position i here corresponds to logit/probability i of the model output.
TAG_COLS = [
    "Data",
    "Action",
    "Gain",
    "Regu",
    "Vague",
]

# A label is reported as predicted when its probability is strictly above this value.
PREDICTION_THRESHOLD = 0.5
|
|
| |
| |
| |
class BertForMultiLabelClassification(BertPreTrainedModel):
    """BERT encoder with a linear head for multi-label classification.

    The head produces one raw logit per label. A ``BCEWithLogitsLoss``
    instance is stored on the module as ``loss_fct``, but ``forward`` does
    not apply it and returns logits only.
    """

    def __init__(self, config):
        """Build the encoder, dropout layer and classification head.

        Args:
            config: Model configuration; ``num_labels``,
                ``hidden_dropout_prob`` and ``hidden_size`` are read from it.
        """
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        # The head's dropout reuses the encoder's hidden dropout probability.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        self.post_init()

        # Stateless loss module; kept as an attribute, not used by forward().
        self.loss_fct = nn.BCEWithLogitsLoss()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
    ):
        """Return the classification logits for the given inputs.

        Args:
            input_ids: Passed through to the BERT encoder.
            attention_mask: Passed through to the BERT encoder.
            token_type_ids: Passed through to the BERT encoder.
            labels: Accepted for signature compatibility; currently ignored.

        Returns:
            The classifier output computed from the encoder's
            ``pooler_output`` after dropout (raw logits, no sigmoid applied).
        """
        encoded = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.classifier(self.dropout(encoded.pooler_output))
|
|
|
|
| |
| |
| |
def predict_multilabel(checkpoint_path: str, tokenizer_path: str, text_to_predict: str):
    """Load a model checkpoint and run multi-label prediction on one text.

    Args:
        checkpoint_path: Checkpoint location handed to ``from_pretrained``
            for both the config and the model weights.
        tokenizer_path: Tokenizer path or name handed to
            ``BertTokenizerFast.from_pretrained``.
        text_to_predict: The input text to classify.

    Returns:
        A dict mapping every tag in ``TAG_COLS`` to
        ``{"predicted": bool, "probability": float}`` (probability rounded
        to 4 decimals), or ``None`` when the model or tokenizer cannot be
        loaded.
    """
    print(f"--- 1. 正在加载模型和分词器: {checkpoint_path} ---")

    try:
        config = AutoConfig.from_pretrained(checkpoint_path)

        expected_num_labels = len(TAG_COLS)
        if config.num_labels != expected_num_labels:
            # Remember the old value BEFORE overwriting it so the warning
            # reports the real "from" value.
            original_num_labels = config.num_labels
            config.num_labels = expected_num_labels
            print(f"警告: 检查点配置的 num_labels 已从 {original_num_labels} 修正为 {expected_num_labels}")

        tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)

        model = BertForMultiLabelClassification.from_pretrained(
            checkpoint_path,
            config=config,
        )
    except Exception as e:
        # Deliberate best-effort boundary: report the problem and signal
        # failure with None instead of propagating the loader's exception.
        print(f"加载模型或分词器失败,请检查路径中是否包含所有必需文件(如 model.safetensors, config.json, vocab.txt): {e}")
        return None

    # Inference mode: disables dropout.
    model.eval()

    inputs = tokenizer(
        text_to_predict,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(**inputs)
        # torch.sigmoid is numerically stable for large-magnitude logits,
        # unlike a hand-written 1 / (1 + exp(-x)). Index 0 selects the single
        # input text's row of per-label probabilities.
        probs = torch.sigmoid(logits)[0].cpu().numpy()

    # Convert NumPy scalars to plain Python bool/float so the result can be
    # serialised (e.g. with json) without extra handling.
    result = {
        tag: {
            "predicted": bool(prob > PREDICTION_THRESHOLD),
            "probability": round(float(prob), 4),
        }
        for tag, prob in zip(TAG_COLS, probs)
    }

    print("--- 5. 预测结果 ---")

    predicted_tags = [tag for tag, info in result.items() if info["predicted"]]

    if predicted_tags:
        print(f"预测标签类别: {predicted_tags}")
        print("对应概率:")
        for tag in predicted_tags:
            print(f"  - {tag}: {result[tag]['probability']}")
    else:
        # Report the configured threshold rather than a hard-coded number.
        print(f"未预测任何标签(所有标签概率均低于 {PREDICTION_THRESHOLD})。")
        print(f"所有标签的最高概率: {max(p['probability'] for p in result.values()):.4f}")

    return result
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Demo run: fine-tuned checkpoint directory, tokenizer name and one sample sentence.
    MODEL_CHECKPOINT = "/home/hsichen/part_time/BERT_finetune/outputs/finbert2_multilabel_model_finetuned_from_dapt/final"
    TOKENIZER = "valuesimplex-ai-lab/FinBERT2-base"

    SAMPLE_TEXT = "密切关注安全环保对原料市场的影响,提前落实应对预案;"

    # Only attempt prediction when the checkpoint path exists on disk.
    if os.path.exists(MODEL_CHECKPOINT):
        predict_multilabel(MODEL_CHECKPOINT, TOKENIZER, SAMPLE_TEXT)
    else:
        print(f"错误:模型检查点目录不存在: {MODEL_CHECKPOINT}")