Edit model card

此模型的作用是对输入的简体七言律诗进行风格上的分类,详情见 https://mp.weixin.qq.com/s/P8FVCkI8-anDuLWQIAgs2w

使用方法如下:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
import torch.nn.functional as F
from zhconv import convert
import re

model_path = "qixun/qilv_classify"

# Load the tokenizer and the fine-tuned sequence-classification model from
# the Hugging Face Hub repository named by model_path.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Optional: to run on a GPU when one is available, uncomment these two lines.
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

# Load the class-index -> style-label mapping. Its keys are class indices
# stored as strings ("0", "1", ...). Adjust the path to wherever
# label_mapping.json lives on your machine.
with open("label_mapping.json", encoding="utf-8") as f:
    label_mapping = json.loads(f.read())


def classify_text(text, top_k=3):
    """Classify the style of a simplified-Chinese seven-character regulated verse (七言律诗).

    The input is converted to simplified Chinese and stripped of all
    whitespace. It is then sanity-checked before being fed to the model.

    Args:
        text: The poem, punctuation included. Spaces and line breaks are ignored.
        top_k: How many of the highest-probability style labels to report
            (default 3). Values larger than the number of model classes are
            clamped to that number.

    Returns:
        A human-readable string. On success it holds the normalized poem
        followed by one "分类: <label>, 概率: <p>" line per predicted label.
        If the input does not look like a 七律, it is a short error message
        in Chinese instead.
    """
    text = convert(text, 'zh-cn')
    # Remove every kind of whitespace (spaces, tabs, \r\n, full-width \u3000),
    # not just " " and "\n". This lets text pasted with Windows line endings
    # or full-width spaces still pass the length check below.
    text = "".join(text.split())

    # A 七律 has 8 lines x 7 characters = 56 characters plus 8 punctuation
    # marks, i.e. 64 characters in total.
    if len(text) != 64:
        return "请输入一首带标点的七言律诗"

    # Reject degenerate input (e.g. a few characters repeated over and over)
    # by requiring at least 30 distinct CJK ideographs.
    unique_characters = set(re.findall(r'[\u4e00-\u9fff]', text))
    if len(unique_characters) < 30:
        return "请输入一首正常的七言律诗"

    # Tokenize the single poem as a batch of one.
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
    # Keep the inputs on the same device as the model, so the function works
    # whether the model sits on CPU or has been moved to a GPU.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    probabilities = F.softmax(logits, dim=-1)

    # Never ask topk for more classes than the model actually has.
    k = min(top_k, probabilities.size(-1))
    top_probs, top_indices = torch.topk(probabilities, k, dim=-1)

    # The batch holds a single poem, so only row 0 is relevant.
    # label_mapping keys are class indices as strings.
    results = [
        (label_mapping[str(index)], prob)
        for prob, index in zip(top_probs[0].tolist(), top_indices[0].tolist())
    ]

    lines = ["文本: {}".format(text)]
    lines.extend("分类: {}, 概率: {:.4f}".format(label, prob) for label, prob in results)
    return "\n".join(lines) + "\n"

# Example call: one 七律 with punctuation (56 characters + 8 marks = 64).
# Odd-numbered lines end with "," and even-numbered lines end with "。".
text = "胎禽消息渺难知,小萼妆容故故迟。城郭渐随寒碧敛,湖山刚与晚阴宜。再来恐或成孤往,此去何由问所之。坐对空亭喧冻雀,可堪暝色向人垂。"
result = classify_text(text)
print(result)

也可以直接在 Hugging Face 模型页面上输入一首连同标点共 64 个字符(56 个字加 8 个标点)的简体七言律诗进行测试。label_mapping.json 的内容为:

{
    "0": "中唐",
    "1": "乱码",
    "2": "冲塔",
    "3": "同光",
    "4": "复兴",
    "5": "实验",
    "6": "晚唐",
    "7": "江西",
    "8": "浙",
    "9": "浣花",
    "10": "理学",
    "11": "盛唐",
    "12": "艳体",
    "13": "诗界xx",
    "14": "赣",
    "15": "闽"
}

大家自行转换。

Downloads last month
3
Safetensors
Model size
103M params
Tensor type
F32
·