# llm-chatbot / app.py
from pathlib import Path
import requests
from threading import Event, Thread
from transformers import AutoConfig, AutoTokenizer, TextIteratorStreamer
from optimum.intel.openvino import OVModelForCausalLM
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
import gradio as gr
from llm_config import SUPPORTED_LLM_MODELS
from notebook_utils import device_widget
# Initialize model language options
model_languages = list(SUPPORTED_LLM_MODELS)
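# Assumed llm_config layout (an illustration, not verified against llm_config):
# SUPPORTED_LLM_MODELS maps language -> model id -> configuration dict, e.g.
#   {"English": {"tiny-llama-1b-chat": {"model_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ...}}}
# Only the "model_id" key is relied on below.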
# Gradio components for selecting model language and model ID
model_language = gr.Dropdown(
    choices=model_languages,
    value=model_languages[0],
    label="Model Language",
)
# Update the model dropdown choices when the selected language changes
def update_model_id(model_language_value):
    model_ids = list(SUPPORTED_LLM_MODELS[model_language_value])
    # A single gr.update matches the single output component wired below
    return gr.update(choices=model_ids, value=model_ids[0])
# Gradio dropdown for selecting a model ID within the chosen language
model_id = gr.Dropdown(
    choices=[],  # populated dynamically by update_model_id
    label="Model",
    value=None,
)
model_language.change(update_model_id, inputs=model_language, outputs=[model_id])
# Gradio checkbox for preparing INT4 model
prepare_int4_model = gr.Checkbox(
    value=True,
    label="Prepare INT4 Model",
)
# Gradio checkbox for enabling AWQ (depends on INT4 checkbox)
enable_awq = gr.Checkbox(
    value=False,
    label="Enable AWQ",
    visible=False,
)
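# The comment above implies the AWQ option should only appear while INT4
# preparation is enabled; a minimal wiring sketch (assumes these listeners,
# like model_language.change above, run inside a gr.Blocks context):
prepare_int4_model.change(
    lambda enabled: gr.update(visible=enabled),
    inputs=prepare_int4_model,
    outputs=enable_awq,
)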
# Device selection widget (e.g., CPU or GPU)
device = device_widget("CPU", exclude=["NPU"])
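# device_widget comes from the OpenVINO notebooks' notebook_utils helper; if that
# module is unavailable, a plain dropdown with a .value attribute is an adequate
# stand-in (an assumption, not part of the original app):
# device = gr.Dropdown(choices=["CPU", "GPU"], value="CPU", label="Device")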
# Model directory and setup based on selections
def get_model_path(model_language_value, model_id_value):
    model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
    pt_model_name = model_id_value.split("-")[0]
    int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
    return model_configuration, int4_model_dir, pt_model_name
# Function to download the model if not already present
def download_model_if_needed(model_language_value, model_id_value):
    model_configuration, int4_model_dir, pt_model_name = get_model_path(model_language_value, model_id_value)
    int4_weights = int4_model_dir / "openvino_model.bin"
    if not int4_weights.exists():
        print(f"Downloading model {model_id_value}...")
        int4_model_dir.mkdir(parents=True, exist_ok=True)
        # Add your download logic here (e.g., from a URL):
        # r = requests.get(model_configuration["model_url"])
        # int4_weights.write_bytes(r.content)
    return int4_model_dir
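# A hedged alternative to fetching prebuilt weights, assuming the Hub id in
# model_configuration["model_id"] is reachable: export the PyTorch model and
# compress it to INT4 locally with optimum-intel, writing into the same
# INT4_compressed_weights layout expected above. Illustrative sketch only.
def export_int4_locally(model_configuration, int4_model_dir):
    from optimum.intel import OVWeightQuantizationConfig

    model = OVModelForCausalLM.from_pretrained(
        model_configuration["model_id"],
        export=True,
        quantization_config=OVWeightQuantizationConfig(bits=4),
    )
    model.save_pretrained(int4_model_dir)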
# Load the model
def load_model(model_language_value, model_id_value):
    int4_model_dir = download_model_if_needed(model_language_value, model_id_value)
    model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
    # Latency-oriented OpenVINO settings: single inference stream, no model cache
    ov_config = {
        hints.performance_mode(): hints.PerformanceMode.LATENCY,
        streams.num(): "1",
        props.cache_dir(): "",
    }
    tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
    ov_model = OVModelForCausalLM.from_pretrained(
        int4_model_dir,
        device=device.value,
        ov_config=ov_config,
        config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
        trust_remote_code=True,
    )
    return tok, ov_model, model_configuration
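# generate_response below reloads the model on every request, which is slow for
# large checkpoints. A minimal memoization sketch (an addition, assuming one
# resident model at a time is acceptable):
_model_cache = {}

def load_model_cached(model_language_value, model_id_value):
    key = (model_language_value, model_id_value)
    if key not in _model_cache:
        _model_cache.clear()  # evict any previously loaded model to bound memory
        _model_cache[key] = load_model(model_language_value, model_id_value)
    return _model_cache[key]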
# Gradio interface function for generating text responses
def generate_response(history, temperature, top_p, top_k, repetition_penalty, model_language_value, model_id_value):
    tok, ov_model, model_configuration = load_model(model_language_value, model_id_value)

    # Convert history to tokens; the exact prompt format is model-specific, so
    # prefer the tokenizer's chat template when one is available.
    def convert_history_to_token(history):
        if tok.chat_template is not None:
            messages = []
            for user_msg, bot_msg in history:
                messages.append({"role": "user", "content": user_msg})
                if bot_msg:
                    messages.append({"role": "assistant", "content": bot_msg})
            return tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
        # Fallback: naive concatenation of the user turns
        return tok(" ".join(msg[0] for msg in history), return_tensors="pt").input_ids

    input_ids = convert_history_to_token(history)
    # TextIteratorStreamer yields decoded text chunks as generation progresses
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=256,
        do_sample=temperature > 0.0,  # sampling parameters are ignored without this
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )
    # Run generation on a worker thread and signal when it completes
    event = Event()

    def generate_and_signal_complete():
        ov_model.generate(**generate_kwargs)
        event.set()

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()
    # Stream partial text into the latest history entry
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
# Gradio UI components
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
top_k = gr.Slider(minimum=0, maximum=50, value=50, label="Top K")
repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty")
# Conversation history input/output
history = gr.State([]) # store the conversation history
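# Note: the interface below exposes no free-text input; generate_response assumes
# history[-1] already holds the pending user turn. A hypothetical helper for
# appending that turn before generation (not wired into the interface):
def add_user_message(message, history):
    return history + [[message, ""]]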
# Gradio Interface
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        history,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        model_language,
        model_id,
    ],
    # generate_response yields a list of [user, assistant] pairs, which is the
    # format gr.Chatbot renders
    outputs=[gr.Chatbot(label="Conversation History")],
    live=True,
    title="OpenVINO Chatbot",
)
# Launch Gradio app
if __name__ == "__main__":
    # Streaming (generator) outputs require the Gradio queue to be enabled
    iface.queue().launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)