# Imports
import gradio as gr
import os
import random
import spaces
import threading

from llama_cpp import Llama
from huggingface_hub import hf_hub_download

# Variables
HF_TOKEN = os.environ.get("HF_TOKEN")

# Repo id is stored reversed (presumably to keep it out of plain-text search);
# it resolves to "MaziyarPanahi/WizardLM-2-8x22B-GGUF".
REPO = "FUGG-B22x8-2-MLdraziW/ihanaPrayizaM"[::-1]
FILE = "WizardLM-2-8x22B.IQ1_M.gguf"

TIMEOUT = 30  # seconds allowed between streamed chunks before generation is aborted

MAX_SEED = 9007199254740991  # 2**53 - 1, the largest integer the browser UI can represent exactly
DEFAULT_INPUT = ""
DEFAULT_SYSTEM = "Statical is a totally normal human named 'Statical'. Statical responds very briefly, truthfully, and accurately."

TAG_USER = "USER"
TAG_ASSISTANT = "STATICAL"

DEFAULT_SEPARATOR = ","
DEFAULT_STOP_SEQUENCES = f"{TAG_USER}:,{TAG_ASSISTANT}:,</s>"
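
# generate() below assembles the system prompt, history, and new input into a
# Vicuna-style transcript (the prompt format WizardLM-2 uses), e.g.:
#
#     Statical is a totally normal human named 'Statical'. ...
#     USER: Hi!
#     STATICAL: Hello.</s>
#     USER: How are you?
#     STATICAL: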

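# Load the IQ1_M (extreme low-bit) quant CPU-only: n_gpu_layers=0 keeps all
# layers off the GPU, and hf_hub_download returns a path into the local HF
# cache, so restarts reuse the already-downloaded file.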
model = Llama(model_path=hf_hub_download(repo_id=REPO, filename=FILE, token=HF_TOKEN), n_ctx=32768, n_threads=48, n_batch=512, n_gpu_layers=0, verbose=True)

# Functions
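# Parse a user-supplied seed string: e.g. "42" -> 42, while "" (or any
# non-digit input) falls back to a random integer in [0, MAX_SEED].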
def get_seed(seed):
    # Treat a blank or missing seed as a request for a random one.
    seed = (seed or "").strip()
    if seed.isdigit():
        return int(seed)
    return random.randint(0, MAX_SEED)

def generate(input=DEFAULT_INPUT, history=None, system=DEFAULT_SYSTEM, stream=False, temperature=1, top_p=0.95, top_k=50, rep_p=1.2, max_tokens=64, seed=None, separator=DEFAULT_SEPARATOR, stop_sequences=DEFAULT_STOP_SEQUENCES):
    print("[GENERATE] Model is generating...")

    # Rebuild the conversation transcript from Gradio's (user, assistant) history pairs.
    memory = ""
    for item in history or []:
        if item[0]:
            memory += f"{TAG_USER}: {item[0].strip()}\n"
        if item[1]:
            memory += f"{TAG_ASSISTANT}: {item[1].strip()}</s>\n"
    prompt = f"{system.strip()}\n{memory}{TAG_USER}: {input.strip()}\n{TAG_ASSISTANT}: "
    
    print(prompt)
    
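    # Split the user-supplied stop-sequence string into a list for llama.cpp;
    # with the defaults, "USER:,STATICAL:,</s>" split on "," gives
    # ["USER:", "STATICAL:", "</s>"].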
    parameters = {
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repeat_penalty": rep_p,
        "max_tokens": max_tokens,
        "stop": [seq.strip() for seq in stop_sequences.split(separator)] if stop_sequences else [],
        "seed": get_seed(seed),
        "stream": stream
    }
    
    event = threading.Event()
    timer = None

    try:
        output = model.create_completion(**parameters)
        print("[GENERATE] Model has generated.")
        if stream:
            buffer = ""
            # Abort if no chunk arrives within TIMEOUT seconds; the timer is
            # re-armed after every chunk, so only stalls trip it, not long outputs.
            timer = threading.Timer(TIMEOUT, event.set)
            timer.start()
            for item in output:
                if event.is_set():
                    raise TimeoutError("[ERROR] Generation timed out.")
                buffer += item["choices"][0]["text"]
                yield buffer
                timer.cancel()
                timer = threading.Timer(TIMEOUT, event.set)
                timer.start()
        else:
            yield output["choices"][0]["text"]
    except TimeoutError as e:
        yield str(e)
    finally:
        # timer is None when stream=False or when create_completion failed early.
        if timer is not None:
            timer.cancel()

@spaces.GPU(duration=15)
def gpu():
    # No-op ZeroGPU entry point; inference itself runs on CPU (n_gpu_layers=0).
    return
    
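# Example of calling generate() directly, outside Gradio -- a minimal sketch;
# history uses the (user, assistant) tuple pairs the function expects:
#
#     for partial in generate("Hello!", history=[], stream=True, max_tokens=32, seed="42"):
#         print(partial)
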
# Initialize
theme = gr.themes.Default(
    primary_hue="violet",
    secondary_hue="indigo",
    neutral_hue="zinc",
    spacing_size="sm",
    radius_size="lg",
    font=[gr.themes.GoogleFont('Kanit'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
    font_mono=[gr.themes.GoogleFont('Kanit'), 'ui-monospace', 'Consolas', 'monospace'],
).set(background_fill_primary='*neutral_50', background_fill_secondary='*neutral_100')

model_base = "https://huggingface.co/alpindale/WizardLM-2-8x22B"
model_quant = "https://huggingface.co/MaziyarPanahi/WizardLM-2-8x22B-GGUF"

with gr.Blocks(theme=theme) as main:
    with gr.Column():
        gr.Markdown("# πŸ‘οΈβ€πŸ—¨οΈ WizardLM")
        gr.Markdown("β €β €β€’ ⚑ A text-generation inference for one of the best open-source models: WizardLM-2-8x22B.")
        gr.Markdown("β €β €β€’ ⚠️ WARNING! Inference is very slow because the model is huge: it takes around 10 seconds before generation even starts. Please avoid high max-token settings and very long inputs. It runs on CPU because I cannot figure out how to run it on GPU without overloading the model.")
        gr.Markdown(f"β €β €β€’ πŸ”— Link to models: {model_base} (BASE), {model_quant} (QUANT)")

    with gr.Column():
        gr.ChatInterface(
            fn=generate,
            additional_inputs_accordion=gr.Accordion(label="βš™οΈ Configurations", open=False, render=False),
            additional_inputs=[
                gr.Textbox(lines=1, value=DEFAULT_SYSTEM, label="πŸͺ„ System", render=False),
                gr.Checkbox(label="⚑ Stream", value=True, render=False),
                gr.Slider(minimum=0, maximum=2, step=0.01, value=1, label="🌑️ Temperature", render=False),
                gr.Slider(minimum=0.01, maximum=0.99, step=0.01, value=0.95, label="🧲 Top P", render=False),
                gr.Slider(minimum=1, maximum=2048, step=1, value=50, label="πŸ“Š Top K", render=False),
                gr.Slider(minimum=0.01, maximum=2, step=0.01, value=1.2, label="πŸ“š Repetition Penalty", render=False),
                gr.Slider(minimum=1, maximum=2048, step=1, value=256, label="⏳ Max New Tokens", render=False),
                gr.Textbox(lines=1, value="", label="🌱 Seed (Blank for random)", render=False),
                gr.Textbox(lines=1, value=DEFAULT_SEPARATOR, label="🏷️ Stop Sequences Separator", render=False),
                gr.Textbox(lines=1, value=DEFAULT_STOP_SEQUENCES, label="πŸ›‘ Stop Sequences (Blank for none)", render=False),
            ]
        )

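# Note: ChatInterface passes additional_inputs to generate() positionally after
# (input, history), so the widget order above must match the function signature.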
main.launch(show_api=False)