import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, TorchAoConfig
from threading import Thread
import os, subprocess, torch
from torchao.quantization import Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, Float8DynamicActivationFloat8WeightConfig
from torchao.dtypes import Int4CPULayout
#subprocess.run("pip list", shell=True)
IS_COMPILE = not torch.cuda.is_available()  # enable the static-KV-cache path only when running on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# https://huggingface.co/docs/transformers/en/quantization/torchao?examples-CPU=int8-dynamic-and-weight-only
if torch.cuda.is_available():
    # float8 dynamic-activation / float8 weight quantization on GPU
    quant_config = Float8DynamicActivationFloat8WeightConfig()
else:
    #quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
    # int8 dynamic-activation / int8 weight quantization on CPU
    quant_config = Int8DynamicActivationInt8WeightConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)
#checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
checkpoint = "unsloth/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32).to(device)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                                             device_map=device, quantization_config=quantization_config).eval()
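# Note: with TorchAoConfig, the selected torchao quantization recipe is applied while the
# checkpoint weights are loaded, so the model above is already quantized after from_pretrained.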
if IS_COMPILE:
    model.generation_config.cache_implementation = "static"
    input_text = "Warming up."
    input_ids = tokenizer(input_text, return_tensors="pt").to(device)
    output = model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
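# Note (author's intent inferred): the warmup generate above appears to exist so the one-time
# cost of allocating the static KV cache (and any graph capture) is paid at startup on CPU,
# not on the first user request.
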
# ZeroGPU allows `duration` to be a callable that receives the same arguments as the
# decorated function and returns the number of GPU seconds to request.
def get_duration(message, history, system_message, max_tokens, temperature, top_p, duration):
    return duration
@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def respond_stream(message, history, system_message, max_tokens, temperature, top_p, duration):
    messages = [{"role": "system", "content": system_message}] + history + [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        input_ids=inputs["input_ids"],
        #attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        num_beams=1,
        output_scores=False,
    )
    if IS_COMPILE:
        gen_kwargs["cache_implementation"] = "static"
    # Run generation in a background thread and yield the growing reply as tokens arrive.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    partial = ""
    for piece in streamer:
        partial += piece
        yield partial
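# respond_stream is defined but not passed to the ChatInterface below, which uses the
# non-streaming respond(). gr.ChatInterface also accepts generator functions, so swapping
# respond_stream in should give token-by-token streaming in the UI.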
@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def respond(message, history, system_message, max_tokens, temperature, top_p, duration):
    messages = [{"role": "system", "content": system_message}] + history + [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    gen_kwargs = dict(
        input_ids=inputs["input_ids"],
        #attention_mask=inputs["attention_mask"],
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        num_beams=1,
        output_scores=False,
    )
    if IS_COMPILE:
        gen_kwargs["cache_implementation"] = "static"
    outputs = model.generate(**gen_kwargs)
    # Strip the prompt tokens and decode only the newly generated continuation.
    gen_ids = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)
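# Hypothetical quick test outside the Gradio UI (values are illustrative only):
#   print(respond("Hello!", [], "You are a friendly Chatbot.", 64, 0.7, 0.95, 30))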
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
        gr.Slider(minimum=1, maximum=360, value=30, step=1, label="Duration (requested GPU seconds)"),
    ],
)
with gr.Blocks() as demo:
    chatbot.render()

if __name__ == "__main__":
    demo.queue().launch()