import gradio as gr
from transformers import AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM
from threading import Thread
import torch

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""

# --- Model Configuration ---
MODEL_NAME_OR_PATH = "TheBloke/Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ"
MODEL_BASENAME = "wizard-vicuna-13b-uncensored-superhot-8k-GPTQ-4bit-128g.no-act.order"
# USE_CUDA is auto-detected and only used for the input-tensor fallback below;
# model placement itself is handled by device_map='auto' in from_quantized
# (CPU is used automatically when no CUDA device is available, just much more slowly).
USE_CUDA = torch.cuda.is_available()

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)

print(f"Loading model {MODEL_NAME_OR_PATH}...")
# For AutoGPTQ, device_map can be 'auto', 'cuda:0', 'cpu', etc.
# 'auto' will try to use GPU if available.
# trust_remote_code=True is necessary for this model's extended context.
model = AutoGPTQForCausalLM.from_quantized(
    MODEL_NAME_OR_PATH,
    model_basename=MODEL_BASENAME,
    use_safetensors=True,
    trust_remote_code=True,
    device_map="auto", # Automatically selects GPU if available, else CPU
    quantize_config=None # Model is already quantized
)
# The model card recommends an 8K sequence length; with trust_remote_code=True the
# repo's config.json (max_position_embeddings = 8192) should already provide the
# extended context. If context-length issues appear, try setting model.seqlen = 8192
# or adjusting the config directly.
print("Model loaded.")

# Determine the device the model was loaded on, for tokenizing inputs
# If device_map="auto", model.device might not be straightforward.
# Transformers usually handle input tensor placement correctly with device_map="auto".
# We'll try to get it, otherwise default to cuda if available, else cpu.
try:
    DEVICE = model.device
except AttributeError:
    DEVICE = torch.device("cuda:0" if USE_CUDA else "cpu")
print(f"Model is on device: {DEVICE}")


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    prompt_parts = []
    if system_message and system_message.strip():
        # How a system message is handled varies by model. The Vicuna-style format
        # used here has no dedicated system slot, so we simply prepend it as
        # leading context before the USER:/ASSISTANT: turns.
        prompt_parts.append(system_message)

    for user_msg, assistant_msg in history:
        if user_msg:
            prompt_parts.append(f"USER: {user_msg}")
        if assistant_msg:
            prompt_parts.append(f"ASSISTANT: {assistant_msg}")

    prompt_parts.append(f"USER: {message}")
    prompt_parts.append("ASSISTANT:") # Model will generate content starting from here

    full_prompt = "\n".join(prompt_parts)

    # Tokenize the input
    # The .to(DEVICE) is important to move tensors to the same device as the model
    inputs = tokenizer(full_prompt, return_tensors="pt", add_special_tokens=True).to(DEVICE)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,        # Don't return the prompt in the output
        skip_special_tokens=True # Don't return special tokens like <s> or </s>
    )

    # Generation parameters. Use greedy decoding when temperature is effectively
    # zero; otherwise sample with the user-supplied temperature and top_p.
    do_sample = temperature > 1e-4
    generation_kwargs = dict(
        **inputs,  # input_ids and attention_mask from the tokenizer
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        # repetition_penalty=1.15,  # optional, suggested in the model card example
    )
    if do_sample:
        generation_kwargs["temperature"] = temperature
        # top_p of exactly 1.0 can interact badly with some samplers; cap it slightly below.
        generation_kwargs["top_p"] = min(top_p, 0.99)


    # Run generation in a separate thread to not block the main thread
    # This allows Gradio to update UI while text is streaming in
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response
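
    # Note: with this bare USER:/ASSISTANT: format the model can occasionally begin
    # a follow-up "USER:" turn on its own; trimming the accumulated response at
    # "USER:" (or adding a stopping criterion) is a possible refinement not applied here.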


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()
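
# Rough local-run sketch (package names assumed from the imports above; exact
# versions may differ):
#   pip install torch transformers auto-gptq gradio
#   python app.py
# Gradio then serves the chat UI at its default local URL (typically http://127.0.0.1:7860).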