import os
from pathlib import Path
import requests
import shutil
import torch
from threading import Event, Thread
from transformers import AutoConfig, AutoTokenizer, TextIteratorStreamer
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
import gradio as gr
from llm_config import SUPPORTED_LLM_MODELS
from notebook_utils import device_widget
# Initialize model language options
model_languages = list(SUPPORTED_LLM_MODELS)

# Gradio components for selecting model language and model ID
model_language = gr.Dropdown(
    choices=model_languages,
    value=model_languages[0],
    label="Model Language"
)
# Gradio dropdown for selecting model ID based on language
def update_model_id(model_language_value):
    model_ids = list(SUPPORTED_LLM_MODELS[model_language_value])
    # Return a single update so the choices and the selected value change together
    return gr.update(choices=model_ids, value=model_ids[0])

model_id = gr.Dropdown(
    choices=[],  # will be dynamically populated
    label="Model",
    value=None
)

model_language.change(update_model_id, inputs=model_language, outputs=model_id)
# Gradio checkbox for preparing INT4 model
prepare_int4_model = gr.Checkbox(
    value=True,
    label="Prepare INT4 Model"
)

# Gradio checkbox for enabling AWQ (depends on the INT4 checkbox; wired below)
enable_awq = gr.Checkbox(
    value=False,
    label="Enable AWQ",
    visible=False
)
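
# Show the AWQ option only when INT4 preparation is enabled. This is a minimal
# sketch of the dependency described above; the exact wiring in the original
# may differ.
prepare_int4_model.change(
    lambda enabled: gr.update(visible=enabled),
    inputs=prepare_int4_model,
    outputs=enable_awq,
)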
# Device selection widget (e.g., CPU or GPU)
device = device_widget("CPU", exclude=["NPU"])
# Model directory and setup based on selections
def get_model_path(model_language_value, model_id_value):
    model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
    pt_model_id = model_configuration["model_id"]
    pt_model_name = model_id_value.split("-")[0]
    int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
    return model_configuration, int4_model_dir, pt_model_name
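
# Example (hypothetical entry names; the real keys come from llm_config):
#   cfg, int4_dir, name = get_model_path("English", "qwen2.5-7b-instruct")
#   int4_dir -> Path("qwen2.5-7b-instruct/INT4_compressed_weights")
#   name     -> "qwen2.5"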
# Function to download the model if not already present
def download_model_if_needed(model_language_value, model_id_value):
    model_configuration, int4_model_dir, pt_model_name = get_model_path(model_language_value, model_id_value)
    int4_weights = int4_model_dir / "openvino_model.bin"
    if not int4_weights.exists():
        print(f"Downloading model {model_id_value}...")
        # Add your download logic here (e.g., from a URL)
        # Example:
        # r = requests.get(model_configuration["model_url"])
        # with open(int4_weights, "wb") as f:
        #     f.write(r.content)
    return int4_model_dir
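
# The INT4 weights can also be produced locally. A minimal sketch, assuming
# optimum-intel's CLI exporter is installed; the original notebook may drive
# the conversion differently (e.g., with AWQ enabled via the checkbox above).
def convert_to_int4(pt_model_id, int4_model_dir):
    import subprocess
    int4_model_dir.mkdir(parents=True, exist_ok=True)
    # optimum-cli export openvino --model <id> --weight-format int4 <out_dir>
    subprocess.run(
        ["optimum-cli", "export", "openvino",
         "--model", pt_model_id,
         "--weight-format", "int4",
         str(int4_model_dir)],
        check=True,
    )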
# Load the model
def load_model(model_language_value, model_id_value):
    int4_model_dir = download_model_if_needed(model_language_value, model_id_value)
    model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
    # OpenVINO config: optimize for latency, single inference stream, no model cache
    ov_config = {
        hints.performance_mode(): hints.PerformanceMode.LATENCY,
        streams.num(): "1",
        props.cache_dir(): "",
    }
    tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
    ov_model = OVModelForCausalLM.from_pretrained(
        int4_model_dir,
        device=device.value,
        ov_config=ov_config,
        config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
        trust_remote_code=True
    )
    return tok, ov_model, model_configuration
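
# Loading the tokenizer and model on every request is expensive. A minimal
# cache (an addition, not part of the original flow) avoids re-loading while
# the selection stays the same; generate_response below could call this
# instead of load_model.
_model_cache = {}

def load_model_cached(model_language_value, model_id_value):
    key = (model_language_value, model_id_value)
    if key not in _model_cache:
        _model_cache[key] = load_model(model_language_value, model_id_value)
    return _model_cache[key]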
# Gradio interface function for generating text responses
def generate_response(history, temperature, top_p, top_k, repetition_penalty, model_language_value, model_id_value):
    tok, ov_model, model_configuration = load_model(model_language_value, model_id_value)

    # Convert history to tokens
    def convert_history_to_token(history):
        # Simplified placeholder: join the user turns into a single prompt.
        # Use model_configuration to determine the exact prompt format.
        input_tokens = tok(" ".join([msg[0] for msg in history]), return_tensors="pt").input_ids
        return input_tokens

    input_ids = convert_history_to_token(history)
    # TextIteratorStreamer yields decoded tokens as they are generated
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    # Generation kwargs; do_sample is required for temperature/top_p/top_k to take effect
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer
    )
    # Run generation in a background thread and signal when it finishes
    event = Event()

    def generate_and_signal_complete():
        ov_model.generate(**generate_kwargs)
        event.set()

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()
    # Collect generated text as it streams in
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
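
# A more faithful prompt builder than the placeholder above. A sketch, assuming
# the tokenizer ships a chat template; the original notebook derives the format
# from model_configuration instead.
def convert_history_to_token_with_template(tok, history):
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # Returns input_ids ready to pass to generate()
    return tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")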
# Gradio UI components
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
top_k = gr.Slider(minimum=0, maximum=50, value=50, label="Top K")
repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty")

# Conversation history input/output
history = gr.State([])  # store the conversation history
# Gradio Interface
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        history,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        model_language,
        model_id
    ],
    # generate_response yields the updated history (a list of message pairs),
    # so a Chatbot component renders it correctly, not a plain Textbox
    outputs=[gr.Chatbot(label="Conversation History")],
    live=True,
    title="OpenVINO Chatbot"
)
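
# Note: gr.Interface streams each yield from the generator to the Chatbot
# output; live=True additionally re-runs the function whenever any input
# changes, which may not be desirable for long generations.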
# Launch Gradio app
if __name__ == "__main__":
    iface.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)