# coding=utf-8
# Copyright 2023 WisdomShell Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on Qwen's web demo. It has been modified from
# its original form to accommodate CodeShell.
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""A simple web interactive chat demo based on gradio.""" | |
import os | |
from argparse import ArgumentParser | |
import gradio as gr | |
import mdtex2html | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from transformers.generation import GenerationConfig | |
DEFAULT_CKPT_PATH = 'WisdomShell/CodeShell-7B-Chat-int4' | |


def _load_model_tokenizer():
    # Tokenizer, model, and generation config all come from the same
    # checkpoint. trust_remote_code is required because CodeShell ships
    # custom modeling code on the Hugging Face Hub.
    tokenizer = AutoTokenizer.from_pretrained(
        DEFAULT_CKPT_PATH, trust_remote_code=True, resume_download=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        DEFAULT_CKPT_PATH,
        trust_remote_code=True,
        resume_download=True,
        torch_dtype=torch.bfloat16,
    ).eval()
    config = GenerationConfig.from_pretrained(
        DEFAULT_CKPT_PATH, trust_remote_code=True, resume_download=True,
    )
    return model, tokenizer, config
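
# Note: WisdomShell/CodeShell-7B-Chat-int4 is the 4-bit quantized variant of
# CodeShell-7B-Chat on the Hugging Face Hub, and torch.bfloat16 assumes
# hardware with bfloat16 support (e.g. recent NVIDIA GPUs; CPU works, but
# slowly).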


def postprocess(self, y):
    """Render each (message, response) pair in the chat history with
    mdtex2html so Markdown and LaTeX in model output display as HTML."""
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


# Monkey-patch Gradio's Chatbot so the history is rendered with the
# converter above instead of the component's default postprocessing.
gr.Chatbot.postprocess = postprocess
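
# Rough usage sketch of the patched renderer (illustrative only; the exact
# HTML depends on the installed mdtex2html version):
#   postprocess(None, [("**hi**", "$x^2$")])
#   -> [('<p><strong>hi</strong></p>', '...x^2 rendered as math...')]  # roughly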


def _parse_text(text):
    """Convert raw model output to HTML, turning ``` fences into
    <pre><code> blocks and escaping Markdown-sensitive characters
    inside them."""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                # Opening fence: anything after the backticks is the language tag.
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # Closing fence.
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a code block: escape characters that Markdown
                    # or HTML would otherwise interpret.
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
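
# For example, with the escaping above:
#   _parse_text("```python\nprint(1)\n```")
#   == '<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>'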


def _gc():
    # Free Python-level garbage and, on CUDA machines, release cached GPU memory.
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _launch_demo(model, tokenizer, config):
    def predict(_query, _chatbot, _task_history):
        print(f"User: {_parse_text(_query)}")
        _chatbot.append((_parse_text(_query), ""))
        full_response = ""
        for response in model.chat(_query, _task_history, tokenizer, generation_config=config, stream=True):
            # Strip CodeShell's end-of-turn markers from the streamed text.
            response = response.replace('|end|', '')
            response = response.replace('|<end>|', '')
            _chatbot[-1] = (_parse_text(_query), _parse_text(response))
            yield _chatbot
            full_response = _parse_text(response)
        print(f"History: {_task_history}")
        _task_history.append((_query, full_response))
        print(f"CodeShell-Chat: {_parse_text(full_response)}")

    def regenerate(_chatbot, _task_history):
        # Re-run the last user query: drop the last turn from both the model
        # history and the visible chat, then stream a fresh answer.
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _chatbot, _task_history)

    def reset_user_input():
        return gr.update(value="")

    def reset_state(_chatbot, _task_history):
        _task_history.clear()
        _chatbot.clear()
        _gc()
        return _chatbot

    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>CodeShell-Chat Bot</center>""")
        chatbot = gr.Chatbot(label='CodeShell-Chat', elem_classes="control-height")
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            empty_btn = gr.Button("🧹 Clear History (清除历史)")
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")

        submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True)
        submit_btn.click(reset_user_input, [], [query])
        empty_btn.click(reset_state, [chatbot, task_history], [chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
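
        # Both submit_btn.click handlers fire on one click: Gradio runs
        # handlers registered on the same event in registration order, so
        # predict starts streaming while reset_user_input clears the textbox.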
gr.Markdown("""\ | |
<font size=2>Note: This demo is governed by the original license of CodeShell. \ | |
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \ | |
including hate speech, violence, pornography, deception, etc. \ | |
(注:本演示受CodeShell的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\ | |
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""") | |
demo.queue().launch() | |


def main():
    print("Loading model...")
    model, tokenizer, config = _load_model_tokenizer()
    print("Model loaded, launching demo...")
    _launch_demo(model, tokenizer, config)


if __name__ == '__main__':
    main()
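
# How to run (the filename is an assumption; it is not given in the source):
#   pip install torch transformers gradio mdtex2html
#   python web_demo.py
# Gradio serves at http://127.0.0.1:7860 by default; pass server_name,
# server_port, or share=True to demo.queue().launch() above to change that.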