# Captured from a Hugging Face Spaces page; the Space status was displayed as "Runtime error".
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
import gradio as gr | |
import torch | |
from transformers import BloomTokenizerFast, BloomForCausalLM | |
from project_settings import project_path | |
def get_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, which is
            the original behavior. Passing a list allows programmatic use.

    Returns:
        argparse.Namespace with:
          * ``trained_model_path`` (str): directory handed to
            ``from_pretrained`` for both tokenizer and model. It defaults to
            ``trained_models/bloom-1b4-sft`` under the project root.
          * ``device`` (str): a torch device string, or ``'auto'`` to let the
            caller choose CUDA when it is available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--trained_model_path',
        default=(project_path / "trained_models/bloom-1b4-sft").as_posix(),
        type=str,
        help="directory holding the fine-tuned tokenizer and model",
    )
    parser.add_argument(
        '--device',
        default='auto',
        type=str,
        help="torch device to run on, or 'auto' to use CUDA when available",
    )
    return parser.parse_args(argv)
def _resolve_device(device):
    """Map the ``--device`` flag to a concrete torch device string.

    ``'auto'`` becomes ``'cuda'`` when torch reports CUDA as available and
    ``'cpu'`` otherwise. Any other value is returned unchanged.
    """
    if device == 'auto':
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    return device


def _build_generate_fn(tokenizer, model, device):
    """Return the Gradio callback that answers a single prompt with ``model``."""

    def fn(text: str,
           max_new_tokens: int = 200,
           top_p: float = 0.85,
           temperature: float = 0.35,
           repetition_penalty: float = 1.2
           ):
        """Generate a sampled reply to ``text`` and return it as a string.

        Gradio numeric inputs may arrive as floats (e.g. 200.0), so each
        generation argument is coerced to the type ``model.generate``
        expects. For example, the max length must be an int.
        """
        print(text)
        # Wrap the query as "<s>{query}</s></s>". This presumably mirrors the
        # SFT training format; confirm against the training code in examples.
        prompt = '<s>{}</s></s>'.format(text)
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        outputs = model.generate(
            input_ids,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            eos_token_id=tokenizer.eos_token_id,
        )
        # generate() returns prompt + continuation, so the prompt is dropped
        # by token position. String-replacing the prompt out of the full
        # decode is fragile because decoding may not reproduce the prompt
        # text byte-for-byte. skip_special_tokens drops <s>/</s> markers.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        output = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        print(output)
        return output

    return fn


def main():
    """Load the fine-tuned Bloom checkpoint and serve it through a Gradio UI.

    Reads ``--trained_model_path`` and ``--device`` via ``get_args``. It then
    loads the tokenizer and model from that path, builds a ``gr.Interface``
    with sampling controls and example prompts, and calls ``demo.launch()``.
    """
    args = get_args()
    device = _resolve_device(args.device)

    # pretrained model
    tokenizer = BloomTokenizerFast.from_pretrained(args.trained_model_path)
    model = BloomForCausalLM.from_pretrained(args.trained_model_path)
    # The callback moves input_ids to `device`, so the model must live there
    # too. Leaving it on the CPU breaks generation when device is 'cuda'.
    model = model.to(device)

    # Markdown shown above the demo. In English: the model is built on the
    # YeungNLP/bloom-1b4-zh pretrained model and the YeungNLP/firefly-train-1.1M
    # dataset, and is trained as an equivalent of YeungNLP/firefly-bloom-1b4.
    # The training code is self-written (in examples), ran for 3 epochs, and
    # the author finds the results acceptable.
    description = """
FireflyBloom1b4
基于 [YeungNLP/bloom-1b4-zh](https://huggingface.co/YeungNLP/bloom-1b4-zh) 预训练模型,
基于 [YeungNLP/firefly-train-1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) 数据集,
训练的等同于 [YeungNLP/firefly-bloom-1b4](https://huggingface.co/YeungNLP/firefly-bloom-1b4) 的问答模型.
训练代码是自己编写的, 在 examples 里, 总共训练了 3 个 epoch. 感觉效果还可以.
"""

    # Example prompts (classical-Chinese translation, NER, zh->en translation,
    # advice, modern->classical translation, couplet, poem writing).
    example_prompts = [
        "将下面句子翻译成现代文:\n石中央又生一树,高百余尺,条干偃阴为五色,翠叶如盘,花径尺余,色深碧,蕊深红,异香成烟,著物霏霏。",
        "实体识别: 1949年10月1日,人们在北京天安门广场参加开国大典。",
        "把这句话翻译成英文: 1949年10月1日,人们在北京天安门广场参加开国大典。",
        "晚上睡不着该怎么办. 请给点详细的介绍.",
        "将下面的句子翻译成文言文:结婚率下降, 离婚率暴增, 生育率下降, 人民焦虑迷茫, 到底是谁的错.",
        "对联:厌烟沿檐烟燕眼.",
        "写一首咏雪的古诗, 标题为 \"沁园春, 雪\".",
    ]
    # Every example uses the same generation settings as the input defaults.
    examples = [[prompt, 200, 0.85, 0.35, 1.2] for prompt in example_prompts]

    demo = gr.Interface(
        fn=_build_generate_fn(tokenizer, model, device),
        inputs=[
            gr.Text(label="text"),
            gr.Number(value=200, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, label="top_p"),
            # Sampling (do_sample=True) rejects a temperature of 0, so the
            # slider starts just above zero.
            gr.Slider(minimum=0.01, maximum=1, step=0.01, value=0.35, label="temperature"),
            gr.Number(value=1.2, label="repetition_penalty"),
        ],
        outputs=[gr.Text(label="output")],
        examples=examples,
        examples_per_page=50,
        title="Firefly Bloom 1b4",
        description=description,
    )
    demo.launch()
    return
if __name__ == "__main__":
    # Script entry point: parse CLI flags, load the model and start the demo.
    main()