lightmate committed
Commit d8164ce (1 parent: e9d91d2)

Update app.py

Files changed (1): app.py (+128, −113)
app.py CHANGED
@@ -1,158 +1,173 @@
  import os
- import torch
- import gradio as gr
  from pathlib import Path
  from transformers import AutoConfig, AutoTokenizer
  from optimum.intel.openvino import OVModelForCausalLM
- from typing import List, Tuple
- from threading import Event, Thread
- from gradio_helper import make_demo  # Your helper function for Gradio demo
- from llm_config import SUPPORTED_LLM_MODELS  # Model configuration
- from notebook_utils import device_widget  # Device selection utility
  import openvino as ov
  import openvino.properties as props
  import openvino.properties.hint as hints
  import openvino.properties.streams as streams
- import requests
-
- # Define the model loading function (same as in your notebook)
- def convert_to_int4(model_id, model_configuration, enable_awq=False):
-     compression_configs = {
-         "qwen2.5-0.5b-instruct": {"sym": True, "group_size": 128, "ratio": 1.0},
-         "default": {"sym": False, "group_size": 128, "ratio": 0.8},
-     }
-     model_compression_params = compression_configs.get(model_id, compression_configs["default"])
-
-     # Example conversion logic
-     int4_model_dir = Path(model_id) / "INT4_compressed_weights"
-     if (int4_model_dir / "openvino_model.xml").exists():
-         return int4_model_dir
-     remote_code = model_configuration.get("remote_code", False)
-     export_command_base = f"optimum-cli export openvino --model {model_configuration['model_id']} --task text-generation-with-past --weight-format int4"
-     int4_compression_args = f" --group-size {model_compression_params['group_size']} --ratio {model_compression_params['ratio']}"
-     if model_compression_params["sym"]:
-         int4_compression_args += " --sym"
-     if enable_awq:
-         int4_compression_args += " --awq --dataset wikitext2 --num-samples 128"
-     export_command_base += int4_compression_args
-     if remote_code:
-         export_command_base += " --trust-remote-code"
-     export_command = export_command_base + f" {str(int4_model_dir)}"
-
-     # Execute export command (shell command)
-     os.system(export_command)
      return int4_model_dir

- # Model and tokenizer loading
- def load_model(model_dir, device):
      ov_config = {hints.performance_mode(): hints.PerformanceMode.LATENCY, streams.num(): "1", props.cache_dir(): ""}
      core = ov.Core()
-     model_name = model_configuration["model_id"]
      tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
-
      ov_model = OVModelForCausalLM.from_pretrained(
          model_dir,
-         device=device,
          ov_config=ov_config,
          config=AutoConfig.from_pretrained(model_dir, trust_remote_code=True),
-         trust_remote_code=True,
      )

-     return ov_model, tok
-
- # Gradio Interface for Bot interaction
- def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
-     input_ids = convert_history_to_token(history)
-     if input_ids.shape[1] > 2000:
-         history = [history[-1]]  # Limit input size
-         input_ids = convert_history_to_token(history)
-
-     streamer = TextIteratorStreamer(tok, timeout=3600.0, skip_prompt=True, skip_special_tokens=True)

      generate_kwargs = dict(
          input_ids=input_ids,
          max_new_tokens=256,
          temperature=temperature,
-         do_sample=temperature > 0.0,
          top_p=top_p,
          top_k=top_k,
          repetition_penalty=repetition_penalty,
-         streamer=streamer,
      )

-     # Function to generate response in a separate thread
      def generate_and_signal_complete():
          ov_model.generate(**generate_kwargs)
-         stream_complete.set()

      t1 = Thread(target=generate_and_signal_complete)
      t1.start()

-     # Process partial text and return updated history
      partial_text = ""
      for new_text in streamer:
-         partial_text = text_processor(partial_text, new_text)
          history[-1][1] = partial_text
          yield history

- # Define a Gradio interface for user interaction
- def create_gradio_interface():
-     # Dropdown for selecting model language and model ID
-     model_language = list(SUPPORTED_LLM_MODELS.keys())  # List of model languages
-     model_id = gr.Dropdown(choices=model_language, value=model_language[0], label="Model Language")
-
-     # Once model language is selected, show the respective model IDs
-     def update_model_ids(model_language):
-         model_ids = list(SUPPORTED_LLM_MODELS[model_language].keys())
-         return gr.Dropdown.update(choices=model_ids, value=model_ids[0])
-
-     model_id_selector = gr.Dropdown(choices=model_language, value=model_language[0], label="Model ID")
-
-     # Set up a checkbox for enabling AWQ compression
-     enable_awq = gr.Checkbox(value=False, label="Enable AWQ for Compression")
-
-     # Initialize model selection based on language and ID
-     def load_model_on_select(model_language, model_id, enable_awq):
-         model_configuration = SUPPORTED_LLM_MODELS[model_language][model_id]
-         int4_model_dir = convert_to_int4(model_id, model_configuration, enable_awq)
-
-         # Load the model and tokenizer
-         device = device_widget("CPU")  # or any device you want to use
-         ov_model, tok = load_model(int4_model_dir, device)
-
-         # Return the loaded model and tokenizer
-         return ov_model, tok
-
-     # Create the Gradio chatbot interface
-     chatbot = gr.Chatbot()
-
-     # Parameters for bot generation
-     temperature = gr.Slider(minimum=0, maximum=1, step=0.1, label="Temperature", value=0.7)
-     top_p = gr.Slider(minimum=0, maximum=1, step=0.1, label="Top-p", value=0.9)
-     top_k = gr.Slider(minimum=0, maximum=50, step=1, label="Top-k", value=50)
-     repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.1, label="Repetition Penalty", value=1.0)
-
-     with gr.Blocks() as demo:
-         # Create the Gradio components and add them to the Blocks context
-         model_id_selector.change(update_model_ids, inputs=model_language, outputs=model_id_selector)
-         load_button = gr.Button("Load Model")
-         load_button.click(load_model_on_select, inputs=[model_language, model_id, enable_awq], outputs=[gr.Textbox(label="Model Status")])
-
-         # Set up the chatbot UI with all the required components
-         gr.Row([model_id_selector, enable_awq])  # Arrange the dropdowns and checkbox in a row
-         gr.Row([load_button])  # Add the button below the inputs
-         gr.Row([chatbot])  # Add the chatbot output
-
-         # Parameters for generation
-         gr.Row([temperature, top_p, top_k, repetition_penalty])  # Add sliders in a row
-
-     # Define bot function and run the interface
-     demo.queue()  # Queue inputs and outputs, handling concurrent generation calls
-     demo.launch(debug=True, share=True)  # For public access
-
-     return demo
-
- # Run the Gradio app
  if __name__ == "__main__":
-     app = create_gradio_interface()
-     app.launch(debug=True, share=True)  # share=True for public access
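
Both revisions hinge on the same streaming pattern: generate() blocks, so it runs on a worker thread while the caller iterates a TextIteratorStreamer. A self-contained sketch with plain transformers (the model name is illustrative and no OpenVINO is involved):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("gpt2")  # any causal LM works; gpt2 keeps it small
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    inputs = tok("OpenVINO is", return_tensors="pt")

    # Run the blocking generate() on a worker thread; the streamer yields text pieces here
    Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=20, streamer=streamer)).start()
    for piece in streamer:
        print(piece, end="", flush=True)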
 
  import os
  from pathlib import Path
+ import requests
+ import shutil
+ import torch
+ from threading import Event, Thread
  from transformers import AutoConfig, AutoTokenizer
+ from transformers import TextIteratorStreamer  # needed by the streaming loop in generate_response
  from optimum.intel.openvino import OVModelForCausalLM
  import openvino as ov
  import openvino.properties as props
  import openvino.properties.hint as hints
  import openvino.properties.streams as streams
+ import gradio as gr
+
+ from llm_config import SUPPORTED_LLM_MODELS
+ from notebook_utils import device_widget
+
+ # Initialize model language options
+ model_languages = list(SUPPORTED_LLM_MODELS)
+
+ # Gradio component for selecting the model language
+ model_language = gr.Dropdown(
+     choices=model_languages,
+     value=model_languages[0],
+     label="Model Language"
+ )
+
+ # Repopulate the model-ID dropdown whenever the language selection changes
+ def update_model_id(model_language_value):
+     model_ids = list(SUPPORTED_LLM_MODELS[model_language_value])
+     return gr.update(choices=model_ids, value=model_ids[0])
+
+ model_id = gr.Dropdown(
+     choices=[],  # populated dynamically by update_model_id
+     label="Model",
+     value=None
+ )
+
+ # NOTE: event wiring like this normally has to happen inside a gr.Blocks() context
+ model_language.change(update_model_id, inputs=model_language, outputs=[model_id])
+
+ # Gradio checkbox for preparing an INT4 model
+ prepare_int4_model = gr.Checkbox(
+     value=True,
+     label="Prepare INT4 Model"
+ )
+
+ # Gradio checkbox for enabling AWQ (only meaningful when INT4 is selected)
+ enable_awq = gr.Checkbox(
+     value=False,
+     label="Enable AWQ",
+     visible=False
+ )
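
The AWQ checkbox starts hidden and nothing in this file toggles it. A minimal sketch of such wiring, assuming it runs inside a gr.Blocks() context (the handler is hypothetical, not part of this commit):

    # Illustrative only: reveal "Enable AWQ" while "Prepare INT4 Model" is ticked
    prepare_int4_model.change(
        lambda prepare: gr.update(visible=prepare),
        inputs=prepare_int4_model,
        outputs=enable_awq,
    )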
+ # Device selection widget (e.g., CPU or GPU)
+ device = device_widget("CPU", exclude=["NPU"])
+
+ # Resolve the configuration and INT4 weight directory for the current selection
+ def get_model_path(model_language_value, model_id_value):
+     model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
+     pt_model_id = model_configuration["model_id"]
+     pt_model_name = model_id_value.split("-")[0]
+     int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
+     return model_configuration, int4_model_dir, pt_model_name
+
+ # Download the model if it is not already present
+ def download_model_if_needed(model_language_value, model_id_value):
+     model_configuration, int4_model_dir, pt_model_name = get_model_path(model_language_value, model_id_value)
+
+     int4_weights = int4_model_dir / "openvino_model.bin"
+
+     if not int4_weights.exists():
+         print(f"Downloading model {model_id_value}...")
+         # Add your download logic here (e.g., from a URL). Example:
+         # int4_model_dir.mkdir(parents=True, exist_ok=True)
+         # r = requests.get(model_configuration["model_url"])
+         # with open(int4_weights, "wb") as f:
+         #     f.write(r.content)
+
      return int4_model_dir
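
The download stub leaves open where the INT4 weights come from. One option, mirroring the convert_to_int4 helper removed above, is to export them locally with optimum-cli; a minimal sketch, where the model ID and output directory are illustrative and the flags copy the removed helper's qwen2.5-0.5b-instruct defaults (group size 128, ratio 1.0, symmetric):

    import subprocess

    # Export an INT4 OpenVINO IR into the directory download_model_if_needed expects
    subprocess.run(
        [
            "optimum-cli", "export", "openvino",
            "--model", "Qwen/Qwen2.5-0.5B-Instruct",  # illustrative model ID
            "--task", "text-generation-with-past",
            "--weight-format", "int4",
            "--group-size", "128",
            "--ratio", "1.0",
            "--sym",
            "qwen2.5-0.5b-instruct/INT4_compressed_weights",
        ],
        check=True,
    )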
+ # Load the tokenizer and OpenVINO model for the current selection
+ def load_model(model_language_value, model_id_value):
+     int4_model_dir = download_model_if_needed(model_language_value, model_id_value)
+
+     # OpenVINO config: latency-oriented, single inference stream, model cache disabled
      ov_config = {hints.performance_mode(): hints.PerformanceMode.LATENCY, streams.num(): "1", props.cache_dir(): ""}
      core = ov.Core()
+
+     model_dir = int4_model_dir
+     model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
+
      tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
      ov_model = OVModelForCausalLM.from_pretrained(
          model_dir,
+         device=device.value,
          ov_config=ov_config,
          config=AutoConfig.from_pretrained(model_dir, trust_remote_code=True),
+         trust_remote_code=True
      )
+
+     return tok, ov_model, model_configuration
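
A usage sketch, assuming "English" and "qwen2.5-0.5b-instruct" are valid keys in SUPPORTED_LLM_MODELS (they are illustrative, not confirmed by this file):

    # Hypothetical keys; substitute entries that exist in llm_config
    tok, ov_model, cfg = load_model("English", "qwen2.5-0.5b-instruct")
    print(type(ov_model).__name__)  # -> OVModelForCausalLM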
+ # Gradio callback: stream a response for the conversation history
+ def generate_response(history, temperature, top_p, top_k, repetition_penalty, model_language_value, model_id_value):
+     tok, ov_model, model_configuration = load_model(model_language_value, model_id_value)
+
+     # Convert history to tokens (placeholder: joins the user turns with spaces;
+     # use model_configuration to apply the model's actual prompt format)
+     def convert_history_to_token(history):
+         input_tokens = tok(" ".join([msg[0] for msg in history]), return_tensors="pt").input_ids
+         return input_tokens
+
+     input_ids = convert_history_to_token(history)
+     streamer = TextIteratorStreamer(tok, timeout=3600.0, skip_prompt=True, skip_special_tokens=True)
+
+     # Generation parameters; greedy decoding when temperature is 0
      generate_kwargs = dict(
          input_ids=input_ids,
          max_new_tokens=256,
          temperature=temperature,
+         do_sample=temperature > 0.0,
          top_p=top_p,
          top_k=top_k,
          repetition_penalty=repetition_penalty,
+         streamer=streamer
      )
+
+     # Run generation on a worker thread and signal when it completes
+     event = Event()

      def generate_and_signal_complete():
          ov_model.generate(**generate_kwargs)
+         event.set()

      t1 = Thread(target=generate_and_signal_complete)
      t1.start()

+     # Drain the streamer, growing the last bot turn as tokens arrive
      partial_text = ""
      for new_text in streamer:
+         partial_text += new_text
          history[-1][1] = partial_text
          yield history
+ # Gradio UI components
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
+ top_k = gr.Slider(minimum=0, maximum=50, value=50, label="Top K")
+ repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty")
+
+ # Conversation history state: a list of [user_text, bot_text] pairs
+ history = gr.State([])
+
+ # Gradio Interface
+ iface = gr.Interface(
+     fn=generate_response,
+     inputs=[
+         history,
+         temperature,
+         top_p,
+         top_k,
+         repetition_penalty,
+         model_language,
+         model_id
+     ],
+     outputs=[gr.Textbox(label="Conversation History")],
+     live=True,
+     title="OpenVINO Chatbot"
+ )
+
+ # Launch the Gradio app
  if __name__ == "__main__":
+     # queue() lets the generator-based callback stream partial results to the client
+     iface.queue().launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)