FireflyBloom1b4 / main.py
qgyd2021's picture
[20230821121100]
f6ff4fa
raw
history blame
4.18 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import gradio as gr
import torch
from transformers import BloomTokenizerFast, BloomForCausalLM
from project_settings import project_path
def get_args():
    """Parse the demo's command-line options.

    Options:
        --trained_model_path: directory holding the tokenizer and model weights;
            defaults to ``trained_models/bloom-1b4-sft`` under ``project_path``.
        --device: torch device string, or ``'auto'`` (the default).

    Returns:
        argparse.Namespace with ``trained_model_path`` and ``device`` (both str).
    """
    default_model_dir = project_path / "trained_models/bloom-1b4-sft"

    parser = argparse.ArgumentParser()
    parser.add_argument('--trained_model_path', default=default_model_dir.as_posix(), type=str)
    parser.add_argument('--device', default='auto', type=str)
    return parser.parse_args()
# Default sampling parameters, shared by ``fn``'s signature, the UI widgets and the examples.
DEFAULT_MAX_NEW_TOKENS = 200
DEFAULT_TOP_P = 0.85
DEFAULT_TEMPERATURE = 0.35
DEFAULT_REPETITION_PENALTY = 1.2

# ``transformers`` requires a strictly positive temperature when ``do_sample=True``,
# but the temperature slider below allows 0, so requested values are clamped to this floor.
MIN_TEMPERATURE = 1e-2

# Markdown shown at the top of the Gradio page (Chinese). It says the model is based on
# YeungNLP/bloom-1b4-zh, trained on YeungNLP/firefly-train-1.1M for 3 epochs with the
# training code in ``examples``, and is equivalent to YeungNLP/firefly-bloom-1b4.
DESCRIPTION = """
FireflyBloom1b4
基于 [YeungNLP/bloom-1b4-zh](https://huggingface.co/YeungNLP/bloom-1b4-zh) 预训练模型,
基于 [YeungNLP/firefly-train-1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) 数据集,
训练的等同于 [YeungNLP/firefly-bloom-1b4](https://huggingface.co/YeungNLP/firefly-bloom-1b4) 的问答模型.
训练代码是自己编写的, 在 examples 里, 总共训练了 3 个 epoch. 感觉效果还可以.
"""

# Example queries offered in the UI. Each one is paired with the default sampling parameters.
EXAMPLE_PROMPTS = (
    "将下面句子翻译成现代文:\n石中央又生一树,高百余尺,条干偃阴为五色,翠叶如盘,花径尺余,色深碧,蕊深红,异香成烟,著物霏霏。",
    "实体识别: 1949年10月1日,人们在北京天安门广场参加开国大典。",
    "把这句话翻译成英文: 1949年10月1日,人们在北京天安门广场参加开国大典。",
    "晚上睡不着该怎么办. 请给点详细的介绍.",
    "将下面的句子翻译成文言文:结婚率下降, 离婚率暴增, 生育率下降, 人民焦虑迷茫, 到底是谁的错.",
    "对联:厌烟沿檐烟燕眼.",
    "写一首咏雪的古诗, 标题为 \"沁园春, 雪\".",
)


def resolve_device(device: str) -> str:
    """Turn the ``--device`` CLI value into a torch device string.

    ``'auto'`` picks ``'cuda'`` when CUDA is available and ``'cpu'`` otherwise.
    Any other value is returned unchanged.
    """
    if device == 'auto':
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    return device


def build_prompt(text: str) -> str:
    """Wrap the user query in the ``<s>{text}</s></s>`` template this demo feeds the model.

    NOTE(review): presumably this mirrors the fine-tuning format used by the
    training code in ``examples`` -- confirm before changing it.
    """
    return '<s>{}</s></s>'.format(text)


def generation_kwargs(max_new_tokens, top_p, temperature, repetition_penalty) -> dict:
    """Normalize UI values into sampling keyword arguments for ``model.generate``.

    Gradio ``Number`` widgets may deliver floats (e.g. ``200.0``), so
    ``max_new_tokens`` is cast to ``int``. ``temperature`` is clamped to
    ``MIN_TEMPERATURE`` because sampling with a zero temperature is rejected.

    Returns:
        dict with ``max_new_tokens``, ``do_sample`` (always True), ``top_p``,
        ``temperature`` and ``repetition_penalty``.
    """
    return {
        'max_new_tokens': int(max_new_tokens),
        'do_sample': True,
        'top_p': float(top_p),
        'temperature': max(float(temperature), MIN_TEMPERATURE),
        'repetition_penalty': float(repetition_penalty),
    }


def main():
    """Load the fine-tuned BLOOM model and serve it through a Gradio text demo.

    Reads ``--trained_model_path`` and ``--device`` via ``get_args()``, loads the
    tokenizer and model from that path, and moves the model to the selected
    device. It then blocks in ``demo.launch()``.
    """
    args = get_args()
    device = resolve_device(args.device)

    # pretrained model
    tokenizer = BloomTokenizerFast.from_pretrained(args.trained_model_path)
    model = BloomForCausalLM.from_pretrained(args.trained_model_path)
    # The prompt tensors are moved to ``device`` inside ``fn``, so the weights must live
    # there too. Otherwise ``generate`` fails with a device mismatch on CUDA.
    model = model.to(device)

    def fn(text: str,
           max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
           top_p: float = DEFAULT_TOP_P,
           temperature: float = DEFAULT_TEMPERATURE,
           repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
           ) -> str:
        """Sample a reply to ``text`` and return only the newly generated text."""
        print(text)
        input_ids = tokenizer(build_prompt(text), return_tensors="pt").input_ids
        input_ids = input_ids.to(device)
        outputs = model.generate(
            input_ids,
            eos_token_id=tokenizer.eos_token_id,
            **generation_kwargs(max_new_tokens, top_p, temperature, repetition_penalty),
        )
        # Decode only the tokens that follow the prompt. Decoding the whole sequence and
        # removing the prompt with ``str.replace`` silently fails whenever the decoded
        # prompt is not byte-identical to the original prompt string.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        output = tokenizer.decode(new_tokens).replace('</s>', "").strip()
        print(output)
        return output

    demo = gr.Interface(
        fn=fn,
        inputs=[
            gr.Text(label="text"),
            gr.Number(value=DEFAULT_MAX_NEW_TOKENS, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=DEFAULT_TOP_P, label="top_p"),
            gr.Slider(minimum=0, maximum=1, value=DEFAULT_TEMPERATURE, label="temperature"),
            gr.Number(value=DEFAULT_REPETITION_PENALTY, label="repetition_penalty"),
        ],
        outputs=[gr.Text(label="output")],
        examples=[
            [prompt, DEFAULT_MAX_NEW_TOKENS, DEFAULT_TOP_P, DEFAULT_TEMPERATURE, DEFAULT_REPETITION_PENALTY]
            for prompt in EXAMPLE_PROMPTS
        ],
        examples_per_page=50,
        title="Firefly Bloom 1b4",
        description=DESCRIPTION,
    )
    demo.launch()
if __name__ == '__main__':
    # Launch the demo only when this file is executed as a script, not when imported.
    main()