---
language: zh
datasets: couplet
inference:
  parameters:
    max_length: 68
    do_sample: True
widget:
- text: "燕子归来,问昔日雕梁何处"
  example_title: "Couplet 1"
- text: "笑取琴书温旧梦"
  example_title: "Couplet 2"
- text: "煦煦春风,吹暖五湖四海"
  example_title: "Couplet 3"
---

# Couplet (对联)

## Model description

AI couplet generation: given the first line of a couplet (上联), the model generates the matching second line (下联).

## How to use

Call the model through a `pipeline`:

```python
>>> from transformers import pipeline
>>> task_prefix = ""
>>> sentence = task_prefix + "国色天香,姹紫嫣红,碧水青云欣共赏"
>>> # Local fine-tuning output directory; the Hub ID "supermy/couplet-helsinki" can be used instead
>>> model_output_dir = 'couplet-hel-mt5-finetuning/'
>>> translation = pipeline("translation", model=model_output_dir)
>>> print(translation(sentence, max_length=28))
[{'translation_text': '月圆花好,良辰美景,良辰美景喜相逢'}]
```

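Because the widget configuration enables `do_sample`, a single call may return a lower line whose shape does not mirror the upper line. A minimal sketch, assuming you want to draw several samples and keep the first one with the same length and punctuation positions as the upper line. `same_shape`, `pick_couplet` and the `PUNCT` set are hypothetical helpers, not part of this model or of `transformers`; `generate` stands in for the `translation` pipeline (any callable returning a list of `{'translation_text': ...}` dicts), and `num_return_sequences` is passed through as a generation argument.

```python
from typing import Callable, Dict, List, Optional

# Assumed set of clause punctuation (ASCII and full-width) used in couplets
PUNCT = set(",,。;;、!!??")


def same_shape(upper: str, lower: str) -> bool:
    """True if both lines have equal length and punctuation at the same positions."""
    if len(upper) != len(lower):
        return False
    return all((a in PUNCT) == (b in PUNCT) for a, b in zip(upper, lower))


def pick_couplet(
    generate: Callable[..., List[Dict[str, str]]],
    upper: str,
    num_samples: int = 5,
    max_length: int = 68,
) -> Optional[str]:
    """Sample several candidate lower lines and return the first well-shaped one."""
    candidates = generate(
        upper,
        max_length=max_length,
        do_sample=True,
        num_return_sequences=num_samples,
    )
    for cand in candidates:
        lower = cand["translation_text"].strip()
        if same_shape(upper, lower):
            return lower
    return None  # no candidate matched; the caller may retry or relax the check
```

With the pipeline from the example above, `pick_couplet(translation, sentence)` would request five samples and keep the first one that lines up character by character with `sentence`.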

To get the features of a given text in PyTorch, load the tokenizer and the base model directly:

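```python
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("supermy/couplet-helsinki")
model = AutoModel.from_pretrained("supermy/couplet-helsinki")

# Encode an upper line; for this encoder-decoder (Marian) model, the
# encoder's last hidden states serve as the text features
inputs = tokenizer("笑取琴书温旧梦", return_tensors="pt")
features = model.get_encoder()(**inputs).last_hidden_state
```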

## Training data

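This dataset is based on the more than 700k couplets in couplet-dataset. The data was filtered against a sensitive-word lexicon to remove vulgar or sensitive content, leaving about 740k couplets (741,096 training examples in the log below).
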
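The filtering step can be pictured as dropping every pair in which either line contains a lexicon entry. A minimal sketch under stated assumptions: the pairs are already loaded as `(upper, lower)` strings, the lexicon is a plain list of words, and simple substring matching is enough. The card does not say which lexicon or matching method was actually used, and `filter_couplets` is a hypothetical name.

```python
from typing import Iterable, List, Tuple

Pair = Tuple[str, str]


def filter_couplets(pairs: Iterable[Pair], lexicon: Iterable[str]) -> List[Pair]:
    """Keep only pairs where neither line contains a word from the lexicon."""
    words = [w for w in set(lexicon) if w]  # drop duplicates and empty entries
    kept: List[Pair] = []
    for upper, lower in pairs:
        if any(w in upper or w in lower for w in words):
            continue  # drop the whole pair if either line is flagged
        kept.append((upper, lower))
    return kept
```

For a lexicon with thousands of entries, a multi-pattern matcher such as Aho-Corasick scales better than this per-word scan.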

## Statistics

## Training procedure

Base model: [Helsinki-NLP/opus-mt-zh-en](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en)

Training hardware: an NVIDIA GPU with 16 GB of memory

mT5 tokenizer: `vocab_size=50000`

```
[INFO|trainer.py:1634] 2022-12-13 06:27:25,113 >> ***** Running training *****
[INFO|trainer.py:1635] 2022-12-13 06:27:25,113 >> Num examples = 741096
[INFO|trainer.py:1636] 2022-12-13 06:27:25,113 >> Num Epochs = 36
[INFO|trainer.py:1637] 2022-12-13 06:27:25,113 >> Instantaneous batch size per device = 256
[INFO|trainer.py:1638] 2022-12-13 06:27:25,113 >> Total train batch size (w. parallel, distributed & accumulation) = 256
[INFO|trainer.py:1639] 2022-12-13 06:27:25,114 >> Gradient Accumulation steps = 1
[INFO|trainer.py:1640] 2022-12-13 06:27:25,114 >> Total optimization steps = 104220
[INFO|trainer.py:1642] 2022-12-13 06:27:25,114 >> Number of trainable parameters = 77419008
[INFO|trainer.py:1663] 2022-12-13 06:27:25,115 >> Continuing training from checkpoint, will skip to saved global_step
[INFO|trainer.py:1664] 2022-12-13 06:27:25,115 >> Continuing training from epoch 2
[INFO|trainer.py:1665] 2022-12-13 06:27:25,115 >> Continuing training from global step 7500

{'loss': 5.5206, 'learning_rate': 4.616340433697947e-05, 'epoch': 2.76}
{'loss': 5.4737, 'learning_rate': 4.5924006908462866e-05, 'epoch': 2.94}
{'loss': 5.382, 'learning_rate': 4.5684609479946274e-05, 'epoch': 3.11}
{'loss': 5.34, 'learning_rate': 4.544473229706391e-05, 'epoch': 3.28}
{'loss': 5.3154, 'learning_rate': 4.520485511418154e-05, 'epoch': 3.45}
......
......
......
{'loss': 3.3099, 'learning_rate': 3.650930723469584e-07, 'epoch': 35.75}
{'loss': 3.3077, 'learning_rate': 1.2521588946459413e-07, 'epoch': 35.92}
{'train_runtime': 41498.9079, 'train_samples_per_second': 642.895, 'train_steps_per_second': 2.511, 'train_loss': 3.675059686432734, 'epoch': 36.0}
***** train metrics *****
epoch = 36.0
train_loss = 3.6751
train_runtime = 11:31:38.90
train_samples = 741096
train_samples_per_second = 642.895
train_steps_per_second = 2.511
12/13/2022 17:59:05 - INFO - __main__ - *** Evaluate ***
[INFO|trainer.py:2944] 2022-12-13 17:59:05,707 >> ***** Running Evaluation *****
[INFO|trainer.py:2946] 2022-12-13 17:59:05,708 >> Num examples = 3834
[INFO|trainer.py:2949] 2022-12-13 17:59:05,708 >> Batch size = 256
100%|██████████| 15/15 [03:25<00:00, 13.69s/it]
[INFO|modelcard.py:449] 2022-12-13 18:02:46,984 >> Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Translation', 'type': 'translation'}, 'metrics': [{'name': 'Bleu', 'type': 'bleu', 'value': 3.7831}]}
***** eval metrics *****
epoch = 36.0
eval_bleu = 3.7831
eval_gen_len = 63.0
eval_loss = 4.5035
eval_runtime = 0:03:40.09
eval_samples = 3834
eval_samples_per_second = 17.419
eval_steps_per_second = 0.068
```

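The step counts in the log follow from the dataset size and the batch size. A small sketch reproducing them, assuming a partial final batch still counts as one optimizer step per epoch and that resumed epochs are counted as `global_step // steps_per_epoch`; both assumptions agree with the logged values. The throughput figures also match the log when computed over all 104220 steps, which suggests they include the 7500 steps completed before training resumed.

```python
import math

# Figures copied from the training log above
num_examples = 741_096
batch_size = 256             # total train batch size: 1 device, no accumulation
num_epochs = 36
resume_step = 7_500          # "Continuing training from global step 7500"
train_runtime = 41_498.9079  # seconds
eval_examples = 3_834

# A partial final batch still costs one optimizer step, so round up
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
# Completed epochs before resuming ("Continuing training from epoch 2")
resumed_epoch = resume_step // steps_per_epoch
# Evaluation batches (the 15/15 progress bar)
eval_batches = math.ceil(eval_examples / batch_size)

# Throughput over the full 36-epoch schedule
samples_per_second = round(num_examples * num_epochs / train_runtime, 3)
steps_per_second = round(total_steps / train_runtime, 3)

print(steps_per_epoch, total_steps, resumed_epoch, eval_batches)
print(samples_per_second, steps_per_second)
# prints:
# 2895 104220 2 15
# 642.895 2.511
```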