# Fujisaki / app.py
# ljsabc's picture
# Refined some parameters.
# 31c18b1
# raw
# history blame
# 3.26 kB
# -*- coding: utf-8 -*-
"""Fujisaki_CPU.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1Damnr0Ha4zZAlKFvne9cu76uuElLNYus
李萌萌的电子骨灰盒
----
这是一个通过ChatGLM模型训练的李萌萌的数字分身,你可以在问题栏目填入内容,或者什么都不填,来观察李萌萌到底会说些什么。
T4级别的GPU已足以胜任这个任务,不过本应用在CPU上运行,速度会比较慢。
### 安装依赖
"""
from modeling_chatglm import ChatGLMForConditionalGeneration
import torch
import sys
from transformers import AutoTokenizer, GenerationConfig
# Hugging Face checkpoint shared by the base model and its tokenizer.
BASE_MODEL = "THUDM/chatglm-6b"

# .float() casts the base weights to float32 for CPU inference.
model = ChatGLMForConditionalGeneration.from_pretrained(BASE_MODEL).float()
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
peft_path = 'ljsabc/Fujisaki_GLM'  # change it to your own

# Wrap the base model with the PEFT adapter weights loaded from peft_path.
model = PeftModel.from_pretrained(
    model,
    peft_path,
    torch_dtype=torch.float,
)
# Inference only: put the model in evaluation mode.
model.eval()

# We have to use full precision, as some tokens are >65535.
# torch.set_default_tensor_type() is deprecated since PyTorch 2.1; with the
# default (CPU) device left unchanged, setting the default dtype to float32 is
# the equivalent of the old set_default_tensor_type(torch.FloatTensor).
torch.set_default_dtype(torch.float32)
def evaluate(context, temperature, top_p, top_k):
    """Generate the model's reply to one prompt.

    The question is wrapped in the ``"Context: ...Answer: "`` template used by
    this app and a single continuation is generated on the CPU.

    Args:
        context: Question text typed by the user (may be empty).
        temperature: Sampling temperature. Values <= 0 switch to greedy
            decoding, because sampling needs a strictly positive temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Number of highest-probability tokens kept at each step.

    Returns:
        The decoded text the model generated after the prompt.
    """
    do_sample = temperature > 0
    generation_config = GenerationConfig(
        # The UI slider allows 0, which is invalid for sampling: fall back to
        # greedy decoding and give the (then unused) temperature a neutral value.
        temperature=temperature if do_sample else 1.0,
        top_p=top_p,
        # transformers' top-k warper requires an int; a slider may send a float.
        top_k=int(top_k),
        num_beams=1,
        do_sample=do_sample,
    )

    input_text = f"Context: {context}Answer: "
    ids = tokenizer.encode(input_text)
    input_ids = torch.LongTensor([ids]).to('cpu')
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            # Note: max_length counts the prompt tokens as well as new ones.
            max_length=160,
            generation_config=generation_config,
        )

    # generate() returns the prompt followed by the continuation. Decode only
    # the new tokens, so a question that itself contains "Answer: " cannot be
    # mistaken for the start of the reply (and no IndexError if the separator
    # does not survive decoding).
    return tokenizer.decode(out[0][len(ids):])
import gradio as gr
# Web UI: one question box and three sampling sliders feeding evaluate(),
# with a single text box showing the generated reply.
demo = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2, label="问题", placeholder="最近过得怎么样?",
            info="可以在这里输入你的问题。也可以什么都不填写生成随机数据。"
        ),
        gr.components.Slider(minimum=0, maximum=1.1, value=1.0, label="Temperature",
                             info="温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。"),
        gr.components.Slider(minimum=0.5, maximum=1.0, value=0.98, label="Top p",
                             info="top-p参数,只输出前p>top-p的文字,建议不要修改。"),
        gr.components.Slider(minimum=1, maximum=200, step=1, value=40, label="Top k",
                             info="top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。"),
    ],
    outputs=[
        # gr.inputs.* is deprecated (removed in Gradio 4); use gr.components,
        # matching the input components above.
        gr.components.Textbox(
            lines=5,
            label="Output",
        )
    ],
    title="李萌萌(Alter Ego)",
    description="这是一个通过ChatGLM模型训练的李萌萌的数字分身,你可以在问题栏目填入内容,或者什么都不填,来观察李萌萌到底会说些什么。因为是在CPU上进行运行,速度会比较慢。",
)
demo.launch()