lightmate committed
Commit 4acb2ad · verified · 1 Parent(s): d8164ce

Update app.py

Files changed (1)
app.py +58 -87
app.py CHANGED
@@ -18,54 +18,13 @@ from notebook_utils import device_widget
  # Initialize model language options
  model_languages = list(SUPPORTED_LLM_MODELS)

- # Gradio components for selecting model language and model ID
- model_language = gr.Dropdown(
-     choices=model_languages,
-     value=model_languages[0],
-     label="Model Language"
- )
-
- # Gradio dropdown for selecting model ID based on language
  def update_model_id(model_language_value):
      model_ids = list(SUPPORTED_LLM_MODELS[model_language_value])
      return model_ids[0], gr.update(choices=model_ids)

- model_id = gr.Dropdown(
-     choices=[],  # will be dynamically populated
-     label="Model",
-     value=None
- )
-
- model_language.change(update_model_id, inputs=model_language, outputs=[model_id])
-
- # Gradio checkbox for preparing INT4 model
- prepare_int4_model = gr.Checkbox(
-     value=True,
-     label="Prepare INT4 Model"
- )
-
- # Gradio checkbox for enabling AWQ (depends on INT4 checkbox)
- enable_awq = gr.Checkbox(
-     value=False,
-     label="Enable AWQ",
-     visible=False
- )
-
- # Device selection widget (e.g., CPU or GPU)
- device = device_widget("CPU", exclude=["NPU"])
-
- # Model directory and setup based on selections
- def get_model_path(model_language_value, model_id_value):
-     model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
-     pt_model_id = model_configuration["model_id"]
-     pt_model_name = model_id_value.split("-")[0]
-     int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
-     return model_configuration, int4_model_dir, pt_model_name
-
  # Function to download the model if not already present
  def download_model_if_needed(model_language_value, model_id_value):
      model_configuration, int4_model_dir, pt_model_name = get_model_path(model_language_value, model_id_value)
-
      int4_weights = int4_model_dir / "openvino_model.bin"

      if not int4_weights.exists():
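Note: the dependent-dropdown wiring this commit relies on is the standard Gradio pattern. A minimal, self-contained sketch for reference; the model table below is a stand-in for the real SUPPORTED_LLM_MODELS import:

import gradio as gr

# Stand-in table; app.py imports the real SUPPORTED_LLM_MODELS.
SUPPORTED_LLM_MODELS = {
    "English": {"model-a-chat": {}, "model-b-instruct": {}},
    "Chinese": {"model-c-chat": {}},
}

def update_model_id(model_language_value):
    model_ids = list(SUPPORTED_LLM_MODELS[model_language_value])
    # One output component, so return a single gr.update that sets
    # both the choices and the selected value.
    return gr.update(choices=model_ids, value=model_ids[0])

with gr.Blocks() as demo:
    model_language = gr.Dropdown(
        choices=list(SUPPORTED_LLM_MODELS),
        value=list(SUPPORTED_LLM_MODELS)[0],
        label="Model Language",
    )
    model_id = gr.Dropdown(choices=[], label="Model", value=None)
    model_language.change(update_model_id, inputs=model_language, outputs=model_id)

if __name__ == "__main__":
    demo.launch()

The committed update_model_id returns a two-element tuple while it is wired to a single output component; collapsing the return to one gr.update, as above, keeps the handler and the outputs list in agreement.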
@@ -75,14 +34,12 @@ def download_model_if_needed(model_language_value, model_id_value):
          # r = requests.get(model_configuration["model_url"])
          # with open(int4_weights, "wb") as f:
          #     f.write(r.content)
-
+
      return int4_model_dir

  # Load the model
  def load_model(model_language_value, model_id_value):
      int4_model_dir = download_model_if_needed(model_language_value, model_id_value)
-
-     # Load the OpenVINO model
      ov_config = {hints.performance_mode(): hints.PerformanceMode.LATENCY, streams.num(): "1", props.cache_dir(): ""}
      core = ov.Core()

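For reference, the ov_config literal in load_model uses the OpenVINO properties API (openvino.properties, .hint, .streams). A hedged sketch of loading the INT4 directory with optimum-intel, assuming the tokenizer was exported alongside the compressed weights:

import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoTokenizer

def load_int4_model(int4_model_dir, device="CPU"):
    ov_config = {
        hints.performance_mode(): hints.PerformanceMode.LATENCY,  # tune for latency
        streams.num(): "1",      # a single inference stream
        props.cache_dir(): "",   # disable the compiled-model cache
    }
    tok = AutoTokenizer.from_pretrained(int4_model_dir)
    ov_model = OVModelForCausalLM.from_pretrained(
        int4_model_dir, device=device, ov_config=ov_config
    )
    return tok, ov_model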
@@ -103,18 +60,9 @@ def load_model(model_language_value, model_id_value):
  # Gradio interface function for generating text responses
  def generate_response(history, temperature, top_p, top_k, repetition_penalty, model_language_value, model_id_value):
      tok, ov_model, model_configuration = load_model(model_language_value, model_id_value)
-
-     # Convert history to tokens
-     def convert_history_to_token(history):
-         # (Your history conversion logic here)
-         # Use model_configuration to determine the exact format
-         input_tokens = tok(" ".join([msg[0] for msg in history]), return_tensors="pt").input_ids
-         return input_tokens
-
-     input_ids = convert_history_to_token(history)
+     input_ids = tok(" ".join([msg[0] for msg in history]), return_tensors="pt").input_ids
      streamer = gr.Textbox.update()

-     # Adjust generation kwargs
      generate_kwargs = dict(
          input_ids=input_ids,
          max_new_tokens=256,
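One caveat the diff leaves untouched: streamer = gr.Textbox.update() is not an iterable token stream, so the for new_text in streamer loop further down has nothing to consume. The usual pairing is transformers.TextIteratorStreamer feeding the UI while generation runs on a background thread; a sketch under that assumption:

from threading import Event, Thread
from transformers import TextIteratorStreamer

def stream_reply(tok, ov_model, input_ids, max_new_tokens=256):
    # Decoded text chunks become available on the streamer as they are generated.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(input_ids=input_ids, max_new_tokens=max_new_tokens, streamer=streamer)
    done = Event()

    def generate_and_signal_complete():
        ov_model.generate(**generate_kwargs)
        done.set()

    Thread(target=generate_and_signal_complete).start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text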
@@ -125,49 +73,72 @@ def generate_response(history, temperature, top_p, top_k, repetition_penalty, mo
          streamer=streamer
      )

-     # Start streaming response
      event = Event()
-
      def generate_and_signal_complete():
          ov_model.generate(**generate_kwargs)
          event.set()
-
+
      t1 = Thread(target=generate_and_signal_complete)
      t1.start()
-
-     # Collect generated text
+
      partial_text = ""
      for new_text in streamer:
          partial_text += new_text
          history[-1][1] = partial_text
          yield history

- # Gradio UI components
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
- top_k = gr.Slider(minimum=0, maximum=50, value=50, label="Top K")
- repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty")
-
- # Conversation history input/output
- history = gr.State([])  # store the conversation history
-
- # Gradio Interface
- iface = gr.Interface(
-     fn=generate_response,
-     inputs=[
-         history,
-         temperature,
-         top_p,
-         top_k,
-         repetition_penalty,
-         model_language,
-         model_id
-     ],
-     outputs=[gr.Textbox(label="Conversation History")],
-     live=True,
-     title="OpenVINO Chatbot"
- )
-
- # Launch Gradio app
+ # Gradio UI within a Blocks context
+ with gr.Blocks() as iface:
+     model_language = gr.Dropdown(
+         choices=model_languages,
+         value=model_languages[0],
+         label="Model Language"
+     )
+
+     model_id = gr.Dropdown(
+         choices=[],  # dynamically populated
+         label="Model",
+         value=None
+     )
+
+     model_language.change(update_model_id, inputs=model_language, outputs=[model_id])
+
+     prepare_int4_model = gr.Checkbox(
+         value=True,
+         label="Prepare INT4 Model"
+     )
+     enable_awq = gr.Checkbox(
+         value=False,
+         label="Enable AWQ",
+         visible=False
+     )
+
+     device = device_widget("CPU", exclude=["NPU"])
+
+     temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
+     top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
+     top_k = gr.Slider(minimum=0, maximum=50, value=50, label="Top K")
+     repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty")
+
+     history = gr.State([])
+
+     iface_interface = gr.Interface(
+         fn=generate_response,
+         inputs=[
+             history,
+             temperature,
+             top_p,
+             top_k,
+             repetition_penalty,
+             model_language,
+             model_id
+         ],
+         outputs=[gr.Textbox(label="Conversation History")],
+         live=True,
+         title="OpenVINO Chatbot"
+     )
+
+     iface_interface.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)
+
  if __name__ == "__main__":
-     iface.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)
+     iface.launch()
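As committed, the app launches twice: iface_interface.launch(...) fires while the Blocks context is still being built, and iface.launch() fires again under __main__. The conventional shape is to build the UI once and launch once; a sketch, keeping the committed server settings:

import gradio as gr

with gr.Blocks(title="OpenVINO Chatbot") as iface:
    # ... components and event wiring as in the diff above ...
    pass

if __name__ == "__main__":
    iface.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)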
 