NeuralChat / app.py
Tonic's picture
Update app.py
0d5c130
raw
history blame
No virus
3.96 kB
import os
import math
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import sentencepiece
# --- Text shown at the top of the Gradio page -------------------------------
title = "Welcome to Tonic's 🐋🐳Orca-2-13B (in 8bit)!"
description = "You can use [🐋🐳microsoft/Orca-2-13b](https://huggingface.co/microsoft/Orca-2-13b) via API using Gradio by scrolling down and clicking Use 'Via API' or privately by [cloning this space on huggingface](https://huggingface.co/spaces/Tonic1/TonicsOrca2?duplicate=true) . [Join my active builders' server on discord](https://discord.gg/VqTxc76K3u). Let's build together! Big thanks to the HuggingFace Organisation for the Community Grant."

# --- Model and tokenizer -----------------------------------------------------
model_name = "microsoft/Orca-2-13b"

# Preferred torch device: first CUDA GPU when one is available, otherwise CPU.
# The model below is placed via device_map="auto" rather than through this value.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Slow (non-"fast") tokenizer implementation is requested explicitly.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# Weights are loaded in 8-bit with automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
)
class OrcaChatBot:
    """Multi-turn chat wrapper that builds ChatML-style prompts for a causal LM.

    The bot keeps a running conversation history and, on every call to
    ``predict``, renders the system message plus the whole history into a
    single prompt of ``<|im_start|>{role}\\n{text}<|im_end|>`` turns.
    """

    def __init__(self, model, tokenizer, system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
        """Store the model, tokenizer and system message; start with empty history.

        Args:
            model: Object exposing ``generate(...)`` and a ``device`` attribute.
            tokenizer: Callable that tokenizes a prompt; must also expose
                ``decode`` and ``eos_token_id``.
            system_message: Text placed in the leading ``system`` turn.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.system_message = system_message
        # Chronological list of (role, message) tuples; role is "user" or "assistant".
        self.conversation_history = []

    def update_conversation_history(self, user_message, assistant_message):
        """Append one user turn followed by one assistant turn to the history."""
        self.conversation_history.append(("user", user_message))
        self.conversation_history.append(("assistant", assistant_message))

    def format_prompt(self):
        """Render the system message and history as a ChatML prompt.

        Turns whose message is empty or whitespace-only are skipped. The
        prompt always ends with an opened ``assistant`` turn so the model
        continues with the assistant's reply.

        Returns:
            str: The complete prompt text.
        """
        # The system message belongs in a "system" turn; emitting it as an
        # "assistant" turn makes the model treat it as something it said.
        segments = [f"<|im_start|>system\n{self.system_message}<|im_end|>\n"]
        segments.extend(
            f"<|im_start|>{role}\n{message}<|im_end|>\n"
            for role, message in self.conversation_history
            if message.strip()
        )
        # No space between the marker and the role name, matching every other turn.
        segments.append("<|im_start|>assistant\n")
        return "".join(segments)

    def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
        """Generate the assistant's reply to ``user_message`` and record the exchange.

        Args:
            user_message: The new user turn.
            temperature: Sampling temperature passed to ``generate``.
            max_new_tokens: Upper bound on the number of generated tokens.
            top_p: Nucleus-sampling threshold passed to ``generate``.
            repetition_penalty: Repetition penalty passed to ``generate``.

        Returns:
            str: Only the newly generated reply text (the prompt is not echoed).
        """
        # Record the user turn first so format_prompt() includes it.
        self.conversation_history.append(("user", user_message))
        prompt = self.format_prompt()

        inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.model.device)

        output_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            # int() guards against UI sliders that deliver floats.
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
        )

        # The returned sequence starts with the prompt tokens; decode only what
        # was generated after them. Decoding everything would echo the system
        # message and the full history back to the user and, once stored, make
        # the history grow with a copy of itself on every turn.
        prompt_length = input_ids.shape[1]
        new_token_ids = output_ids[0][prompt_length:]
        response = self.tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

        self.conversation_history.append(("assistant", response))
        return response
# Single module-level chatbot instance. gradio_predict routes every request
# through it, so its conversation_history is shared by all requests.
Orca_bot = OrcaChatBot(model, tokenizer)
def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty):
    """Gradio callback: answer ``user_message`` using the shared ``Orca_bot``.

    Args:
        user_message: Text typed into the "Your Message" box.
        system_message: Optional system prompt from the UI. When non-blank it
            replaces the bot's system message for this request only.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Repetition penalty.

    Returns:
        str: The text returned by ``Orca_bot.predict``.
    """
    # The UI labels this field as a system prompt, so use it as one instead of
    # gluing it onto the user's message (which would also store it in the
    # conversation history as part of every user turn).
    previous_system_message = Orca_bot.system_message
    if system_message and system_message.strip():
        Orca_bot.system_message = system_message
    try:
        # Keyword arguments: the slider order in the UI differs from the
        # parameter order of predict().
        return Orca_bot.predict(
            user_message,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
    finally:
        # Restore the previous system message so the override does not leak
        # into later requests.
        Orca_bot.system_message = previous_system_message
# Input widgets, listed in the same order as gradio_predict's parameters:
# message, system prompt, max new tokens, temperature, top-p, repetition penalty.
_input_components = [
    gr.Textbox(label="Your Message", type="text", lines=3),
    gr.Textbox(label="Introduce a Character Here or Set a Scene (system prompt)", type="text", lines=2),
    gr.Slider(label="Max new tokens", value=420, minimum=25, maximum=2056, step=1),
    gr.Slider(label="Temperature", value=0.1, minimum=0.05, maximum=1.0, step=0.05),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05),
    gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05),
]

# Build the web UI around gradio_predict with a plain-text output.
iface = gr.Interface(
    fn=gradio_predict,
    inputs=_input_components,
    outputs="text",
    title=title,
    description=description,
    theme="ParityError/Anime",
)

# Start serving the interface.
iface.launch()