#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio web demo for a fine-tuned Bloom causal-LM checkpoint (FireflyBloom1b4)."""
import argparse

import gradio as gr
import torch
from transformers import BloomTokenizerFast, BloomForCausalLM

from project_settings import project_path


def get_args():
    """Parse command-line options.

    Returns:
        argparse.Namespace with:
          - ``trained_model_path``: directory handed to ``from_pretrained``
            (defaults to ``<project>/trained_models/bloom-1b4-sft``).
          - ``device``: ``'auto'`` or any torch device string (e.g. ``'cpu'``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--trained_model_path',
        default=(project_path / "trained_models/bloom-1b4-sft").as_posix(),
        type=str,
    )
    parser.add_argument('--device', default='auto', type=str)
    args = parser.parse_args()
    return args


def _resolve_device(device: str) -> str:
    """Map ``'auto'`` to ``'cuda'`` when CUDA is available (else ``'cpu'``); pass other values through."""
    if device == 'auto':
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    return device


def main():
    """Load the tokenizer/model, build the Gradio interface and launch it (blocks until the server stops)."""
    args = get_args()
    device = _resolve_device(args.device)

    # pretrained model
    tokenizer = BloomTokenizerFast.from_pretrained(args.trained_model_path)
    # The model must live on the same device as the inputs built in ``fn``;
    # previously only the inputs were moved, which fails on CUDA.
    model = BloomForCausalLM.from_pretrained(args.trained_model_path).to(device)
    model.eval()

    description = """
FireflyBloom1b4 基于 [YeungNLP/bloom-1b4-zh](https://huggingface.co/YeungNLP/bloom-1b4-zh) 预训练模型, 基于 [YeungNLP/firefly-train-1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) 数据集, 训练的等同于 [YeungNLP/firefly-bloom-1b4](https://huggingface.co/YeungNLP/firefly-bloom-1b4) 的问答模型.
训练代码是自己编写的, 在 examples 里, 总共训练了 3 个 epoch.
感觉效果还可以.
"""

    def fn(text: str,
           max_new_tokens: int = 200,
           top_p: float = 0.85,
           temperature: float = 0.35,
           repetition_penalty: float = 1.2
           ):
        """Sample a continuation of ``text`` and return only the newly generated text.

        Args:
            text: user prompt.
            max_new_tokens: upper bound on generated tokens (cast to int; Gradio
                number widgets may deliver floats).
            top_p: nucleus-sampling threshold.
            temperature: sampling temperature; must be > 0 when sampling.
            repetition_penalty: penalty applied to already-seen tokens.

        Returns:
            The generated answer, without the prompt and without special tokens.
        """
        print(text)
        # The prompt is currently passed through verbatim.
        # NOTE(review): if the checkpoint was trained with a wrapping template
        # (e.g. special tokens around the input), apply it here — confirm
        # against the training code.
        prompt = '{}'.format(text)
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            eos_token_id=tokenizer.eos_token_id,
        )
        # ``generate`` returns prompt + continuation for this decoder-only model;
        # slice off the prompt tokens instead of string-replacing the prompt,
        # and drop special tokens such as EOS.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        output = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        print(output)
        return output

    demo = gr.Interface(
        fn=fn,
        inputs=[
            gr.Text(label="text"),
            gr.Number(value=200, precision=0, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
            # temperature must stay strictly positive when sampling is enabled.
            gr.Slider(minimum=0.01, maximum=1, value=0.35, step=0.01, label="temperature"),
            gr.Number(value=1.2, label="repetition_penalty"),
        ],
        outputs=[gr.Text(label="output")],
        examples=[
            [
                "将下面句子翻译成现代文:\n石中央又生一树,高百余尺,条干偃阴为五色,翠叶如盘,花径尺余,色深碧,蕊深红,异香成烟,著物霏霏。",
                200, 0.85, 0.35, 1.2
            ],
            [
                "实体识别: 1949年10月1日,人们在北京天安门广场参加开国大典。",
                200, 0.85, 0.35, 1.2
            ],
            [
                "把这句话翻译成英文: 1949年10月1日,人们在北京天安门广场参加开国大典。",
                200, 0.85, 0.35, 1.2
            ],
            [
                "晚上睡不着该怎么办. 请给点详细的介绍.",
                200, 0.85, 0.35, 1.2
            ],
            [
                "将下面的句子翻译成文言文:结婚率下降, 离婚率暴增, 生育率下降, 人民焦虑迷茫, 到底是谁的错.",
                200, 0.85, 0.35, 1.2
            ],
            [
                "对联:厌烟沿檐烟燕眼.",
                200, 0.85, 0.35, 1.2
            ],
            [
                "写一首咏雪的古诗, 标题为 \"沁园春, 雪\".",
                200, 0.85, 0.35, 1.2
            ],
        ],
        examples_per_page=50,
        title="Firefly Bloom 1b4",
        description=description,
    )
    demo.launch()


if __name__ == '__main__':
    main()