import json
import os
import time
import random
import torch
import gc
import re
import math
import gradio as gr
import numpy as np
import boto3
import logging
from botocore.exceptions import NoCredentialsError
from collections import defaultdict
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


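# Fetch the 1-bit KV-cache quantizer codebook from S3 at import time.
# Assumes AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY are set in the environment;
# the bucket name and object key below are fixed for this demo.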
def download_xmad_file():
    s3 = boto3.client(
        's3',
        aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
    )

    codebooks_dir = '.codebooks'
    os.makedirs(codebooks_dir, exist_ok=True)

    temp_file_path = os.path.join(codebooks_dir, 'llama-3-8b-instruct_1bit.xmad')

    try:
        s3.download_file('xmad-quantized-models', 'llama-3-8b-instruct_1bit.xmad', temp_file_path)
        print("Download Successful")
        os.chmod(codebooks_dir, 0o700)
    except NoCredentialsError:
        print("Credentials not available")


download_xmad_file()


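# Memory helpers: TorchTracemalloc records the peak GPU memory delta (in MB) of
# each `with` block into the class-level track_memory_consumption list, which
# infer() later sums for the per-question memory metric.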
def b2mb(x):
    """
    Convert bytes to megabytes.
    """
    return int(x / 2**20)


class TorchTracemalloc:
    """
    A context manager that clears GPU memory
    and returns GPU peak memory & GPU memory usage.
    """
    track_memory_consumption = []

    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        self.begin = torch.cuda.memory_allocated()
        return self

    def __exit__(self, *exc):
        torch.cuda.synchronize()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)
        TorchTracemalloc.track_memory_consumption.append(self.peaked)


def clear_gpu_memory():
    torch.cuda.empty_cache()
    gc.collect()
    print("GPU memory cleared.")


def format_response(dialog, response):
    question = next((turn['content'] for turn in dialog if turn['role'] == 'user'), 'No question found')
    return {"question": question, "answer": response}


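# The model and tokenizer are loaded once at startup and kept in module-level
# globals so the Gradio handlers can reuse them. For quantized runs, the path to
# the .xmad codebook is attached to the model config as `quantizer_path`; a
# "<PAD>" token is added (Llama 3 ships without one) and the embeddings are
# resized if the vocabulary grew.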
global_model = None
global_tokenizer = None


def load_model_and_tokenizer(model_name, dtype, kv_bits):
    global global_model, global_tokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    special_tokens = {"pad_token": "<PAD>"}
    tokenizer.add_special_tokens(special_tokens)

    config = AutoConfig.from_pretrained(model_name)
    if kv_bits != "unquantized":
        quantizer_path = f".codebooks/{model_name.split('/')[-1]}_{kv_bits}bit.xmad"
        setattr(config, "quantizer_path", quantizer_path)

    if dtype == "bf16":
        dtype = torch.bfloat16
    elif dtype == "fp16":
        dtype = torch.float16
    elif dtype == "fp32":
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=dtype, device_map="auto")

    print(f"Quantizer path in model config: {model.config.quantizer_path}")
    logging.info(f"Quantizer path in model config: {model.config.quantizer_path}")

    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))

    tokenizer.padding_side = "left"
    model.config.pad_token_id = tokenizer.pad_token_id

    global_model = model
    global_tokenizer = tokenizer


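# Build the list of dialogs to run. Only the questions typed into the UI are used
# here; prompts_path is accepted for interface compatibility but not read.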
def load_questions(prompts_path, custom_questions):
    selected_dialogs = []
    if custom_questions:
        for question in custom_questions:
            if question.strip():
                custom_dialog = [{"role": "user", "content": question}]
                selected_dialogs.append(custom_dialog)
    return selected_dialogs


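# Strip the small subset of Markdown the model tends to emit (bold, italics,
# headings, bullets, inline code/quotes) so responses render cleanly in a plain
# Gradio textbox.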
def markdown_to_plain_text(markdown_text):
    # Uppercase bold text. A lambda is required here: calling .upper() on the
    # replacement string r'\1' is a no-op, since the backreference has no letters.
    markdown_text = re.sub(r'\*\*(.*?)\*\*', lambda m: m.group(1).upper(), markdown_text)
    # Drop italics markers.
    markdown_text = re.sub(r'\*(.*?)\*', r'\1', markdown_text)
    # Drop heading markers.
    markdown_text = re.sub(r'### ', '', markdown_text)
    # Drop list bullets.
    markdown_text = re.sub(r'^\s*[-*]\s+', '', markdown_text, flags=re.MULTILINE)
    # Drop inline code, strikethrough, and quote characters.
    markdown_text = re.sub(r'[`~>]', '', markdown_text)
    return markdown_text


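# Run batched generation over all dialogs, streaming partial results back to the
# UI: yields the formatted response text after each batch, then a final dict of
# timing/memory metrics. Note that the tokens/sec figure counts every token in the
# returned sequences (prompt + generated), and TTFT is approximated from the first
# batch's total generation time.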
def infer(model_name, dialogs, num_new_tokens, temperature, dtype, kv_bits, progress=gr.Progress()):
    print("Starting inference...")
    global global_model, global_tokenizer

    model = global_model
    tokenizer = global_tokenizer

    batch_inputs = [
        tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=True)
        for dialog in dialogs
    ]

    responses = []
    plain_text_responses = ""
    start_time = time.time()
    batch_size = max(1, min(100, len(dialogs)))  # guard against an empty question list
    num_dialogs = len(dialogs)
    total_time = 0
    total_tokens = 0
    total_ttft = 0

    memory_avg = []
    tokens_per_sec_avg = []
    time_to_first_token_avg = []
    responses_by_batch_size = defaultdict(list)
    batch_generation_time = 0
    total_generation_time = 0

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    with TorchTracemalloc() as tt:
        for i in range(0, num_dialogs, batch_size):
            batch = batch_inputs[i : i + batch_size]
            try:
                encoded_inputs = tokenizer(
                    batch,
                    padding=True,
                    truncation=False,
                    return_tensors="pt",
                )

                input_ids = encoded_inputs["input_ids"].to(model.device)
                attention_mask = encoded_inputs["attention_mask"].to(model.device)

                torch.cuda.synchronize()
                start_time = time.perf_counter()

                with torch.no_grad():
                    output_tokens = model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=num_new_tokens,
                        num_return_sequences=1,
                        do_sample=True,
                        temperature=temperature,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=terminators,
                    )

                torch.cuda.synchronize()
                end_time = time.perf_counter()

                batch_time = end_time - start_time
                total_time += batch_time
                batch_generation_time += batch_time
                total_generation_time += batch_time
                total_tokens += output_tokens.numel()

                if i == 0:
                    total_ttft = batch_time

                decoded_outputs = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)

                for j, response in enumerate(decoded_outputs):
                    original_dialog = dialogs[i + j]
                    formatted_response = format_response(original_dialog, response)
                    responses.append(formatted_response)

                # Strip the echoed prompt from each decoded answer; the hard-coded
                # offsets match the decoded chat-template prefix around the question.
                formatted_responses = "\n\n====================\n\n".join([f"**Question**:\t{res['question']}\n\n**Answer**: {res['answer'][4+len(res['question'])+11:]}" for res in responses])
                plain_text_responses = markdown_to_plain_text(formatted_responses)
                yield plain_text_responses
                progress(i, desc="Processing batches")

                torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error processing batch {i//batch_size + 1}: {str(e)}")
                continue

    elapsed_time = total_time
    tokens_per_second = total_tokens / total_time if total_time > 0 else 0
    total_memory_consumption = np.sum(TorchTracemalloc.track_memory_consumption)
    avg_memory_consumption = total_memory_consumption / num_dialogs if num_dialogs > 0 else 0

    ttft = total_ttft / batch_size if batch_size > 0 else 0

    print(f"Inference completed in {elapsed_time:.2f} seconds.")

    yield {
        "Time Taken (seconds)": elapsed_time,
        "Tokens per Second": tokens_per_second,
        "Time to First Token (seconds)": ttft,
        "Formatted Responses": plain_text_responses,
        "Memory Consumption per Question (MB)": avg_memory_consumption,
        "Total Memory Consumption (MB)": total_memory_consumption,
        "Num Dialogs": num_dialogs
    }


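# Gradio callback: parse the questions from the textbox, stream intermediate text
# updates from infer(), then emit the final metrics once generation finishes.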
def demo(num_new_tokens, temperature, custom_questions_text, kv_bits=1, progress=gr.Progress()):
    custom_questions = custom_questions_text.split("\n")
    print("Loading questions...")
    dialogs = load_questions("chats_sys_none.json", custom_questions)
    print(f"{len(dialogs)} questions loaded. Starting inference...")

    result_gen = infer("NousResearch/Meta-Llama-3-8B-Instruct", dialogs, num_new_tokens, temperature, "fp16", kv_bits, progress=progress)

    formatted_responses = ""
    num_dialogs = 0
    for result in result_gen:
        if isinstance(result, str):
            formatted_responses = result
            # Intermediate update: only the response textbox changes (7 outputs in total).
            yield None, None, None, None, None, None, formatted_responses
        else:
            time_taken = result["Time Taken (seconds)"]
            tokens_per_second = result["Tokens per Second"]
            ttft = result["Time to First Token (seconds)"]
            avg_memory_consumption = result["Memory Consumption per Question (MB)"]
            total_memory_consumption = result["Total Memory Consumption (MB)"]
            num_dialogs = result["Num Dialogs"]
            formatted_responses = result["Formatted Responses"]
            yield time_taken, tokens_per_second, ttft, avg_memory_consumption, num_dialogs, total_memory_consumption, formatted_responses


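# Default question pool. chats_sys_none.json is expected to be a list of dialogs,
# each a list of {"role", "content"} turns with the user turn first (inferred from
# how it is indexed below).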
with open("chats_sys_none.json", "r") as file: |
|
json_data = json.load(file) |
|
|
|
|
|
def load_default_questions(): |
|
random.shuffle(json_data) |
|
default_questions = [dialog[0]['content'] for dialog in json_data[:60] if 'content' in dialog[0]] |
|
return "\n".join(default_questions) |
|
|
|
|
|
def load_questions_action(): |
|
return load_default_questions() |
|
|
|
|
|
css = """ |
|
body, html { |
|
height: 100vh; |
|
margin: 0; |
|
} |
|
|
|
.gradio-container { |
|
height: 100vh; |
|
} |
|
|
|
#main-row { |
|
height: 100%; |
|
display: flex; |
|
} |
|
|
|
#control-panel{ |
|
height: 100%; |
|
box-sizing: border-box; |
|
display: flex; |
|
flex-direction: column; |
|
overflow: hidden; |
|
flex: 1; |
|
} |
|
|
|
#control-panel, #formatted-responses-container { |
|
height: 100%; |
|
box-sizing: border-box; |
|
display: flex; |
|
flex-direction: column; |
|
overflow: hidden; |
|
flex: 1; |
|
} |
|
|
|
#control-panel { |
|
flex: 1; |
|
padding-bottom: 1vh; /* Add some padding to the bottom */ |
|
} |
|
|
|
#custom-questions-text { |
|
height: 30vh; /* Fixed height for custom questions text */ |
|
overflow-y: auto; |
|
} |
|
|
|
#metrics-panel { |
|
display: flex; |
|
flex-wrap: wrap; |
|
flex-shrink: 0; |
|
height: auto; /* Let the panel size adjust based on its content */ |
|
} |
|
|
|
#metrics-panel .metric { |
|
flex: 1 1 48%; |
|
min-width: 10vw; |
|
box-sizing: border-box; |
|
} |
|
|
|
#buttons-container { |
|
display: flex; |
|
justify-content: space-between; |
|
height: 6vh; /* Fixed height for buttons container */ |
|
flex-shrink: 0; |
|
margin-bottom: 1vh; /* Add margin to prevent cutting off */ |
|
} |
|
""" |
|
|
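# UI layout: a control panel (sliders, question box, metrics, buttons) on the left
# and the streamed responses textbox beside it in the same row.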
with gr.Blocks(css=css) as app:
    with gr.Row(elem_id="main-row", equal_height=True):
        with gr.Column(elem_id="control-panel", scale=1):
            num_new_tokens = gr.Slider(label="Number of New Tokens", minimum=128, maximum=2048, step=128, value=512)
            temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.4)
            custom_questions_text = gr.Textbox(
                label="Custom Questions",
                placeholder="Type your custom questions here, one per line... \nOr press the \"Load Default Questions\" button to load 60 random default questions. \nAdd a question by adding a new line, or delete lines to decrease the number of questions. \nP.S.: YOU CAN COMPARE OUR SPEED & QUALITY VS. GPT4o by COPY/PASTING ALL QUESTIONS HERE INTO CHATGPT",
                autoscroll=False,
                container=False,
                lines=5,
                elem_id="custom-questions-text"
            )
            with gr.Row(elem_id="metrics-panel"):
                time_taken = gr.Number(label="Time Taken (seconds)", interactive=False, elem_classes=["metric"])
                tokens_per_second = gr.Number(label="Tokens per Second", interactive=False, elem_classes=["metric"])
                ttft = gr.Number(label="Time to First Token (seconds)", interactive=False, elem_classes=["metric"])
                total_memory_consumption = gr.Number(label="Memory Consumption (MB)", interactive=False, elem_classes=["metric"])
                num_dialogs = gr.Number(label="Dialogs Processed", interactive=False, elem_classes=["metric"])
                avg_memory_consumption = gr.Number(label="Mem. Consumption per Question (MB)", interactive=False, elem_classes=["metric"])
            with gr.Row(elem_id="buttons-container"):
                load_questions_btn = gr.Button("Load Default Questions")
                demo_btn = gr.Button("Run Inference", elem_id="run-inference-btn", variant="primary")

        formatted_responses = gr.Textbox(
            label="Formatted Responses",
            elem_id="formatted-responses",
            value="No responses yet. Run the inference to see results.",
            lines=37,
            container=False,
            autoscroll=False,
            show_copy_button=True
        )

    load_questions_btn.click(fn=load_questions_action, inputs=[], outputs=custom_questions_text)
    demo_btn.click(demo, inputs=[num_new_tokens, temperature, custom_questions_text], outputs=[time_taken, tokens_per_second, ttft, avg_memory_consumption, num_dialogs, total_memory_consumption, formatted_responses])


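# Entry point: load the fp16 model with the 1-bit KV-cache quantizer once, then
# launch the app behind basic auth read from AUTH_USERNAME / AUTH_PASSWORD.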
if __name__ == "__main__":
    print("Loading model and tokenizer on startup...")
    load_model_and_tokenizer("NousResearch/Meta-Llama-3-8B-Instruct", "fp16", "1")
    print("Model and tokenizer loaded. Starting Gradio interface...")
    username = os.getenv("AUTH_USERNAME")
    password = os.getenv("AUTH_PASSWORD")
    app.launch(auth=(username, password))
|