# Captured from a Hugging Face Spaces page; Space status at capture time: "Runtime error".
# -*- coding: utf-8 -*-
"""Fujisaki_CPU.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1Damnr0Ha4zZAlKFvne9cu76uuElLNYus

Li Mengmeng's electronic urn
----
This is a digital alter ego of Li Mengmeng, trained with the ChatGLM model.
You can type something into the question box, or leave it empty, and see what
Li Mengmeng ends up saying.

A T4-class GPU is already more than enough for this task.

### Install dependencies
"""
import sys

import torch
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import AutoTokenizer, GenerationConfig

from modeling_chatglm import ChatGLMForConditionalGeneration

# Identifier of the LoRA adapter that is applied on top of the base model.
peft_path = 'ljsabc/Fujisaki_GLM'  # change it to your own

# Base ChatGLM-6B weights, converted to float32.
# Original author's note: we have to use full precision, as some tokens
# are >65535.
model = ChatGLMForConditionalGeneration.from_pretrained("THUDM/chatglm-6b")
model = model.float()

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# Wrap the base model with the fine-tuned adapter, also in float32.
model = PeftModel.from_pretrained(model, peft_path, torch_dtype=torch.float)

# Inference only: switch the model to evaluation mode.
model.eval()

# Make float32 CPU tensors the default type for tensors created from here on.
torch.set_default_tensor_type(torch.FloatTensor)
def evaluate(context, temperature, top_p, top_k):
    """Generate one reply from the fine-tuned model for the given question.

    The prompt is built as ``"Context: {context}Answer: "``, encoded with the
    module-level ``tokenizer`` and passed to the module-level ``model`` on
    the CPU with ``max_length=160``.

    Args:
        context: Question text from the UI. ``None`` is treated as an empty
            string so the literal word "None" never ends up in the prompt.
        temperature: Sampling temperature. A value that is ``None`` or not
            strictly positive selects greedy decoding, because sampling with
            a non-positive temperature is rejected by transformers.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Number of candidate tokens kept per step (used only when
            sampling).

    Returns:
        The text that follows the first ``"Answer: "`` marker in the decoded
        output (up to the next marker, if any). If the marker is missing
        from the decoded output, the whole decoded text is returned instead
        of raising ``IndexError``.
    """
    if context is None:
        context = ""

    if temperature is not None and float(temperature) > 0:
        generation_config = GenerationConfig(
            temperature=float(temperature),
            top_p=top_p,
            # top-k is an integer count; the slider may deliver it as a float.
            top_k=int(top_k),
            num_beams=1,
            do_sample=True,
        )
    else:
        # The temperature slider allows 0; sample-free (greedy) decoding is
        # the well-defined equivalent of "no randomness".
        generation_config = GenerationConfig(
            num_beams=1,
            do_sample=False,
        )

    with torch.no_grad():
        input_text = f"Context: {context}Answer: "
        ids = tokenizer.encode(input_text)
        input_ids = torch.LongTensor([ids]).to('cpu')
        out = model.generate(
            input_ids=input_ids,
            max_length=160,
            generation_config=generation_config,
        )

    decoded = tokenizer.decode(out[0])
    pieces = decoded.split("Answer: ")
    # Same result as before whenever the marker survives the decode
    # round-trip; otherwise fall back to the full text rather than crash.
    return pieces[1] if len(pieces) > 1 else decoded
import gradio as gr

# Web UI: one question box plus three sampling controls feed `evaluate`;
# the generated reply is shown in a five-line text box. The interface is
# built and launched immediately when this script runs.
gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2,
            label="问题",
            placeholder="最近过得怎么样?",
            info="可以在这里输入你的问题。也可以什么都不填写生成随机数据。",
        ),
        gr.components.Slider(
            minimum=0,
            maximum=1.1,
            value=1.0,
            label="Temperature",
            info="温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。",
        ),
        gr.components.Slider(
            minimum=0.5,
            maximum=1.0,
            value=0.98,
            label="Top p",
            info="top-p参数,只输出前p>top-p的文字,建议不要修改。",
        ),
        gr.components.Slider(
            minimum=1,
            maximum=200,
            step=1,
            value=40,
            label="Top k",
            info="top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。",
        ),
    ],
    outputs=[
        # Use the unified `gr.components` namespace like the inputs above;
        # the legacy `gr.inputs` namespace is deprecated and is meant for
        # input widgets, not outputs.
        gr.components.Textbox(
            lines=5,
            label="Output",
        ),
    ],
    title="李萌萌(Alter Ego)",
    description="这是一个通过ChatGLM模型训练的李萌萌的数字分身,你可以在问题栏目填入内容,或者什么都不填,来观察李萌萌到底会说些什么。因为是在CPU上进行运行,速度会比较慢。",
).launch()