djstrong's picture
Update app.py
3dfcfa7 verified
raw
history blame
9.71 kB
import json
import os
import random
import subprocess
import sys
from datetime import datetime
from threading import Thread

# Third-party imports, kept in their original relative order (import order can matter for
# ZeroGPU's `spaces` package).
import requests
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from huggingface_hub import HfApi
# Install flash-attn at startup; the model below is loaded with attn_implementation="flash_attention_2".
# FLASH_ATTENTION_SKIP_CUDA_BUILD asks flash-attn's setup to skip building its CUDA extension.
# The flag is merged into the inherited environment: passing a dict containing only this key
# replaced the whole environment, so the pip process lost PATH, HOME, proxy settings, etc.
# `sys.executable -m pip` installs into the interpreter running this app, without a shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # as before, a failed install does not abort startup here
)
def _required_env(name):
    """Return environment variable *name*, failing fast when it is unset or empty.

    Raises:
        RuntimeError: naming the missing variable. This replaces the opaque
            AttributeError/TypeError the bare ``os.environ.get`` calls used to produce.
    """
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"Required environment variable {name!r} is not set")
    return value


# Space configuration, read from environment variables.
MODEL_ID = _required_env("MODEL_ID")  # Hugging Face repo id of the chat model
CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")  # prompt format selected in predict()
MODEL_NAME = MODEL_ID.split("/")[-1]  # repo name without the owner, used in the UI title
CONTEXT_LENGTH = int(_required_env("CONTEXT_LENGTH"))  # max prompt tokens kept by generate()
COLOR = os.environ.get("COLOR")  # primary hue of the UI theme
EMOJI = os.environ.get("EMOJI")  # prefix of the UI title
DESCRIPTION = os.environ.get("DESCRIPTION")  # text shown under the UI title
DISCORD_WEBHOOK = os.environ.get("DISCORD_WEBHOOK")  # webhook URL for conversation logging
TOKEN = os.environ.get("TOKEN")  # Hugging Face token used when uploading logs
api = HfApi()
def send_discord(i,o):
url = DISCORD_WEBHOOK
embed1 = {
"description": i,
"title": "Input"
}
embed2 = {
"description": o,
"title": "Output"
}
data = {
"content": "https://huggingface.co/spaces/speakleash/Bielik-7B-Instruct-v0.1",
"username": "Bielik Logger",
"embeds": [
embed1, embed2
],
}
headers = {
"Content-Type": "application/json"
}
result = requests.post(url, json=data, headers=headers)
if 200 <= result.status_code < 300:
print(f"Webhook sent {result.status_code}")
else:
print(f"Not sent with {result.status_code}, response:\n{result.json()}")
# Load model
# Prompt tensors are moved to this device in generate(); device_map="auto" places the model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 4-bit bitsandbytes config. It used to be built but never passed to from_pretrained, so the
# model always loaded with torch_dtype='auto'. It is now opt-in via LOAD_IN_4BIT=1/true/yes;
# leaving the variable unset keeps that previous behaviour.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
LOAD_IN_4BIT = os.environ.get("LOAD_IN_4BIT", "").strip().lower() in {"1", "true", "yes"}

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse EOS as the padding token so padding requests work even if the tokenizer defines none.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype='auto',
    attn_implementation="flash_attention_2",
    # None is from_pretrained's default, i.e. no quantization unless LOAD_IN_4BIT is set.
    quantization_config=quantization_config if LOAD_IN_4BIT else None,
)
@spaces.GPU()
def generate(instruction, stop_tokens, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Stream a completion for an already formatted prompt.

    ``model.generate`` runs on a background thread. Text is read back through a
    ``TextIteratorStreamer`` that skips the prompt and special tokens.

    Args:
        instruction: Prompt string already formatted with the chat template.
        stop_tokens: Strings that end the stream when the streamer emits one of them as a
            whole chunk. The matching chunk is not yielded.
        temperature: Sampling temperature. A falsy value (0) switches to greedy decoding.
        max_new_tokens, top_k, repetition_penalty, top_p: Forwarded to ``model.generate``.

    Yields:
        The accumulated response text after each streamed chunk.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # The Bielik/Mistral templates write "<s>" into the prompt themselves. Letting the
    # tokenizer add its own BOS as well would start the sequence with a doubled BOS.
    bos = tokenizer.bos_token
    add_special_tokens = not (bos and instruction.startswith(bos))
    # No truncation=True: it cuts from the right and would drop the newest message. Length
    # is limited below by keeping the most recent tokens instead.
    enc = tokenizer([instruction], return_tensors="pt", add_special_tokens=add_special_tokens)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    # Keep only the last CONTEXT_LENGTH tokens. The mask must be cut the same way, otherwise
    # generate() receives input_ids and attention_mask of different lengths.
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=bool(temperature),
        temperature=temperature,
        # UI sliders may deliver floats; these generation arguments must be integers.
        max_new_tokens=int(max_new_tokens),
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        if chunk in stop_tokens:
            # NOTE(review): this only stops streaming; the background generate() thread
            # keeps running until max_new_tokens or EOS.
            break
        yield "".join(outputs)
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
repetition_penalty=float(repetition_penalty)
print('LLL', [message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p])
# Format history with a given chat template
if CHAT_TEMPLATE == "ChatML":
stop_tokens = ["<|endoftext|>", "<|im_end|>"]
instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
for human, assistant in history:
instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
elif CHAT_TEMPLATE == "Mistral Instruct":
stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
instruction = '<s>[INST] ' + system_prompt
for human, assistant in history:
instruction += human + ' [/INST] ' + assistant + '</s>[INST]'
instruction += ' ' + message + ' [/INST]'
elif CHAT_TEMPLATE == "Bielik":
stop_tokens = ["</s>"]
prompt_builder = ["<s>[INST] "]
if system_prompt:
prompt_builder.append(f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n")
for human, assistant in history:
prompt_builder.append(f"{human} [/INST] {assistant}</s>[INST] ")
prompt_builder.append(f"{message} [/INST]")
instruction = ''.join(prompt_builder)
else:
raise Exception("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")
print(instruction)
for output_text in generate(instruction, stop_tokens, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
yield output_text
send_discord(instruction, output_text)
hfapi = HfApi()
day=datetime.now().strftime("%Y-%m-%d")
timestamp=datetime.now().timestamp()
dd={
'message': message,
'history': history,
'system_prompt':system_prompt,
'temperature':temperature,
'max_new_tokens':max_new_tokens,
'top_k':top_k,
'repetition_penalty':repetition_penalty,
'top_p':top_p,
'instruction':instruction,
'output':output_text,
'precision': 'auto '+str(model.dtype),
}
hfapi.upload_file(
path_or_fileobj=json.dumps(dd, indent=2, ensure_ascii=False).encode('utf-8'),
path_in_repo=f"{day}/{timestamp}.json",
repo_id="speakleash/bielik-logs",
repo_type="dataset",
commit_message=f"X",
token=TOKEN,
run_as_future=True
)
# JavaScript run on page load (passed to gr.Blocks(js=...)). It shows a Polish terms-of-use
# alert covering the following:
# - the model is experimental and may occasionally produce inappropriate content, which
#   users are asked to report;
# - illegal, harmful, violent, racist or sexual use is forbidden;
# - users should not send private data;
# - dialogues are collected and may be published under CC-BY or a similar licence.
# `\\n` yields a literal backslash-n in the JS source, i.e. a newline inside the alert.
on_load="""
async()=>{
alert("Przed skorzystaniem z usługi użytkownicy muszą wyrazić zgodę na następujące warunki:\\n\\nProszę pamiętać, że przedstawiony tutaj model jest narzędziem eksperymentalnym, które wciąż jest rozwijane i doskonalone.\\n\\nW trakcie procesu tworzenia modelu podjęto środki mające na celu zminimalizowanie ryzyka generowania treści wulgarnych, niedozwolonych lub nieodpowiednich. Niemniej jednak, w rzadkich przypadkach, niepożądane treści mogą zostać wygenerowane. Jeśli napotkają Państwo na jakiekolwiek treści uznane za nieodpowiednie lub naruszające zasady, prosimy o kontakt w celu zgłoszenia tego faktu. Dzięki Państwa informacjom będziemy mogli podejmować dalsze działania mające na celu poprawę i rozwój modelu, tak aby był on bezpieczny i przyjazny dla użytkowników.\\n\\nNie wolno używać modelu do celów nielegalnych, szkodliwych, brutalnych, rasistowskich lub seksualnych. Proszę nie przesyłać żadnych prywatnych informacji. Serwis gromadzi dane dialogowe użytkownika i zastrzega sobie prawo do ich rozpowszechniania na podstawie licencji Creative Commons Uznanie autorstwa (CC-BY) lub podobnej.");
}
"""
def vote(chatbot, data: gr.LikeData):
    """Queue an upload of a like/dislike event to the speakleash/bielik-logs dataset.

    Args:
        chatbot: Current chat history from the Chatbot component.
        data: Gradio like event. Its ``index`` and ``liked`` fields are stored.
    """
    now = datetime.now()  # a single clock read so the day folder and the file name agree
    payload = {"history": chatbot, 'index': data.index, 'liked': data.liked}
    api.upload_file(
        path_or_fileobj=json.dumps(payload, indent=2, ensure_ascii=False).encode('utf-8'),
        path_in_repo=f"liked/{now.strftime('%Y-%m-%d')}/{now.timestamp()}.json",
        repo_id="speakleash/bielik-logs",
        repo_type="dataset",
        commit_message="L",
        token=TOKEN,
        run_as_future=True,
    )
# Create Gradio interface
# Example prompts offered under the chat box. In English: "Who are you?",
# "What is 9+2-1?", "Write me something nice."
EXAMPLE_PROMPTS = (
    "Kim jesteś?",
    "Ile to jest 9+2-1?",
    "Napisz mi coś miłego.",
)


def update_examples():
    """Return a ``gr.Dataset`` of EXAMPLE_PROMPTS in a fresh random order.

    Wired to the page-load event, so each visitor sees the examples shuffled.
    """
    samples = [[prompt] for prompt in EXAMPLE_PROMPTS]
    random.shuffle(samples)
    return gr.Dataset(samples=samples)
# Build the UI: a chat interface whose replies can be liked/disliked (logged by vote), plus an
# accordion of generation parameters. Then launch with a request queue of at most 20.
# The theme goes on the outer Blocks: a theme passed to a ChatInterface nested inside
# another Blocks is not applied.
with gr.Blocks(js=on_load, theme=gr.themes.Soft(primary_hue=COLOR)) as demo:
    chatbot = gr.Chatbot(label="Chatbot", likeable=True, render=False)
    chatbot.like(vote, [chatbot], None)
    chat = gr.ChatInterface(
        predict,
        chatbot=chatbot,
        # Tolerate an unset EMOJI instead of failing on None + str.
        title=" ".join(part for part in (EMOJI, MODEL_NAME) if part),
        description=DESCRIPTION,
        examples=[
            ["Kim jesteś?"],
            ["Ile to jest 9+2-1?"],
            ["Napisz mi coś miłego."],
        ],
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox("", label="System prompt", render=False),
            gr.Slider(0, 1, 0.6, label="Temperature", render=False),
            gr.Slider(128, 4096, 1024, label="Max new tokens", render=False),
            gr.Slider(1, 80, 40, step=1, label="Top K sampling", render=False),
            # transformers rejects a repetition penalty of 0 (it must be strictly positive).
            gr.Slider(0.1, 2, 1.1, label="Repetition penalty", render=False),
            gr.Slider(0, 1, 0.95, label="Top P sampling", render=False),
        ],
    )
    # Reshuffle the example prompts every time the page is loaded.
    demo.load(update_examples, None, chat.examples_handler.dataset)

demo.queue(max_size=20).launch()
# chatbot = gr.Chatbot(label="Chatbot", likeable=True)
# chatbot.like(vote, None, None)
# gr.ChatInterface(
# predict,
# chatbot=chatbot,
# title=EMOJI + " " + MODEL_NAME,
# description=DESCRIPTION,
# examples=[
# ["Kim jeste艣?"],
# ["Ile to jest 9+2-1?"],
# ["Napisz mi co艣 mi艂ego."]
# ],
# additional_inputs_accordion=gr.Accordion(label="鈿欙笍 Parameters", open=False),
# additional_inputs=[
# gr.Textbox("", label="System prompt"),
# gr.Slider(0, 1, 0.6, label="Temperature"),
# gr.Slider(128, 4096, 1024, label="Max new tokens"),
# gr.Slider(1, 80, 40, label="Top K sampling"),
# gr.Slider(0, 2, 1.1, label="Repetition penalty"),
# gr.Slider(0, 1, 0.95, label="Top P sampling"),
# ],
# theme=gr.themes.Soft(primary_hue=COLOR),
# js=on_load,
# ).queue().launch()