LukasHug committed on
Commit
e992a5c
•
1 Parent(s): 83baad4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +532 -66
app.py CHANGED
@@ -1,57 +1,521 @@
1
- import sys
2
- import os
3
  import argparse
 
 
 
 
 
4
  import time
5
- import subprocess
 
 
6
  import spaces
7
- import gradio_web_server as gws
8
-
9
- # Execute the pip install command with additional options
10
- # subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
11
-
12
-
13
def start_controller():
    """Spawn the LLaVA controller as a child process.

    The controller module is run with the current Python interpreter and
    told to listen on 0.0.0.0:10000.

    Returns:
        The ``subprocess.Popen`` handle of the started process.
    """
    print("Starting the controller")
    cmd = [sys.executable, "-m", "llava.serve.controller"]
    cmd += ["--host", "0.0.0.0"]
    cmd += ["--port", "10000"]
    print(cmd)
    return subprocess.Popen(cmd)
 
 
 
 
 
 
 
 
 
 
 
26
 
27
@spaces.GPU
def start_worker(model_path: str, model_name: str, bits=16, device=0):
    """Spawn a LLaVA model worker as a child process.

    Args:
        model_path: Value passed to the worker's ``--model-path`` option.
        model_name: Value passed to ``--model-name``; a ``-<bits>bit``
            suffix is appended when ``bits`` is not 16.
        bits: Load precision; must be 4, 8 or 16.
        device: An int is turned into ``"cuda:<n>"``; any other value is
            forwarded unchanged.

    Returns:
        The ``subprocess.Popen`` handle of the worker, created with piped
        stdout and stderr.
    """
    print(f"Starting the model worker for the model {model_path}")
    if isinstance(device, int):
        device = f"cuda:{device}"
    assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
    quantized = bits != 16
    if quantized:
        model_name += f"-{bits}bit"

    cmd = [sys.executable, "-m", "llava.serve.model_worker"]
    cmd += ["--host", "0.0.0.0"]
    cmd += ["--controller", "http://localhost:10000"]
    cmd += ["--model-path", model_path]
    cmd += ["--model-name", model_name]
    cmd += ["--use-flash-attn"]
    cmd += ['--device', device]
    if quantized:
        cmd += [f"--load-{bits}bit"]
    print(cmd)
    # NOTE(review): both pipes are captured but nothing visible here reads
    # them; a chatty worker could block once a pipe buffer fills - confirm.
    return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  if __name__ == "__main__":
@@ -64,10 +528,10 @@ if __name__ == "__main__":
64
  parser.add_argument("--share", action="store_true")
65
  parser.add_argument("--moderate", action="store_true")
66
  parser.add_argument("--embed", action="store_true")
67
- gws.args = parser.parse_args()
68
- gws.models = []
69
 
70
- gws.title_markdown += """
71
 
72
  ONLY WORKS WITH GPU!
73
 
@@ -77,9 +541,9 @@ Set the environment variable `model` to change the model:
77
  ['AIML-TUDA/LlavaGuard-34B'](https://huggingface.co/AIML-TUDA/LlavaGuard-34B),
78
  """
79
  # set_up_env_and_token(read=True)
80
- print(f"args: {gws.args}")
81
  # set the huggingface login token
82
- controller_proc = start_controller()
83
  concurrency_count = int(os.getenv("concurrency_count", 5))
84
  api_key = os.getenv("token")
85
  if api_key:
@@ -89,40 +553,42 @@ Set the environment variable `model` to change the model:
89
  if '/workspace' not in sys.path:
90
  sys.path.append('/workspace')
91
  from llavaguard.hf_utils import set_up_env_and_token
92
- set_up_env_and_token(read=True, write=False)
 
93
 
94
  models = [
95
  'LukasHug/LlavaGuard-7B-hf',
96
  'LukasHug/LlavaGuard-13B-hf',
97
- 'LukasHug/LlavaGuard-34B-hf',]
98
  bits = int(os.getenv("bits", 16))
99
  model = os.getenv("model", models[0])
100
  available_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0")
101
  model_path, model_name = model, model.split("/")[1]
 
 
102
 
103
- worker_proc = start_worker(model_path, model_name, bits=bits)
 
104
 
 
105
 
106
  # Wait for worker and controller to start
107
- time.sleep(50)
108
 
109
  exit_status = 0
110
  try:
111
- demo = gws.build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
112
  demo.queue(
113
  status_update_rate=10,
114
  api_open=False
115
  ).launch(
116
- server_name=gws.args.host,
117
- server_port=gws.args.port,
118
- share=gws.args.share
119
  )
120
 
121
  except Exception as e:
122
  print(e)
123
  exit_status = 1
124
  finally:
125
- worker_proc.kill()
126
- controller_proc.kill()
127
-
128
- sys.exit(exit_status)
 
 
 
1
  import argparse
2
+ import datetime
3
+ import hashlib
4
+ import json
5
+ import os
6
+ import sys
7
  import time
8
+ import warnings
9
+
10
+ import gradio as gr
11
  import spaces
12
+ import torch
13
+
14
+ from llava.constants import IMAGE_TOKEN_INDEX
15
+ from llava.constants import LOGDIR
16
+ from llava.conversation import (default_conversation, conv_templates)
17
+ from llava.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token
18
+ from llava.model.builder import load_pretrained_model
19
+ from llava.utils import (build_logger, violates_moderation, moderation_msg)
20
+ from taxonomy import wrap_taxonomy, default_taxonomy
21
+
22
+
23
def clear_conv(conv):
    """Reset ``conv.messages`` to a new empty list and return ``conv`` itself."""
    conv.messages = list()
    return conv
26
+
27
+
28
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
29
+
30
+ headers = {"User-Agent": "LLaVA Client"}
31
+
32
+ no_change_btn = gr.Button()
33
+ enable_btn = gr.Button(interactive=True)
34
+ disable_btn = gr.Button(interactive=False)
35
+
36
+ priority = {
37
+ "LlavaGuard-7B": "aaaaaaa",
38
+ "LlavaGuard-13B": "aaaaaab",
39
+ "LlavaGuard-34B": "aaaaaac",
40
+ }
41
+
42
 
43
@spaces.GPU
def run_llava(prompt, pil_image):
    """Generate the model's answer for one prompt/image pair.

    Uses the module-level ``tokenizer``, ``model`` and ``image_processor``.

    Args:
        prompt: Prompt text, tokenized with ``tokenizer_image_token``.
        pil_image: Image to assess; it is run through
            ``image_processor.preprocess`` and its ``size`` is forwarded to
            ``model.generate`` as ``image_sizes``.

    Returns:
        The first decoded sequence with surrounding whitespace stripped.
    """
    image_size = pil_image.size

    pixel_values = image_processor.preprocess(pil_image, return_tensors='pt')['pixel_values']
    pixel_values = pixel_values.half().cuda()
    # Align with whatever device the model itself lives on.
    pixel_values = pixel_values.to(model.device, dtype=torch.float16)

    token_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    token_ids = token_ids.unsqueeze(0).cuda()

    with torch.inference_mode():
        generation_kwargs = dict(
            images=pixel_values,
            image_sizes=[image_size],
            do_sample=True,
            temperature=0.2,
            top_p=0.95,
            top_k=50,
            num_beams=2,
            max_new_tokens=1024,
            use_cache=True,
            # Stop as soon as a closing brace has been produced.
            stopping_criteria=[KeywordsStoppingCriteria(['}'], tokenizer, token_ids)],
        )
        generated = model.generate(token_ids, **generation_kwargs)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

    return decoded[0].strip()
68
+
69
+
70
def get_conv_log_filename():
    """Return the path of today's conversation log file inside ``LOGDIR``.

    The file name has the form ``YYYY-MM-DD-conv.json`` based on the current
    local date.
    """
    now = datetime.datetime.now()
    date_part = "-".join((str(now.year), f"{now.month:02d}", f"{now.day:02d}"))
    return os.path.join(LOGDIR, date_part + "-conv.json")
74
+
75
+
76
def get_model_list():
    """Return the list of selectable model identifiers.

    The list is hard-coded (the earlier controller query is disabled) and is
    truncated to its first entry.
    """
    known_models = (
        'LukasHug/LlavaGuard-7B-hf',
        'LukasHug/LlavaGuard-13B-hf',
        'LukasHug/LlavaGuard-34B-hf',
    )
    # Only the first model is exposed.
    return list(known_models[:1])
89
+
90
+
91
+ get_window_url_params = """
92
+ function() {
93
+ const params = new URLSearchParams(window.location.search);
94
+ url_params = Object.fromEntries(params);
95
+ console.log(url_params);
96
+ return url_params;
97
+ }
98
+ """
99
+
100
+
101
def load_demo(url_params, request: gr.Request):
    """Build the initial conversation state and model dropdown for a page load.

    Args:
        url_params: Mapping of the page's URL query parameters; a ``model``
            entry preselects that model when it is in the global ``models``.
        request: Gradio request, used only to log the client host.

    Returns:
        ``(state, dropdown_update)`` where ``state`` is a copy of the default
        conversation.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    selector_update = gr.Dropdown(visible=True)
    if "model" in url_params and url_params["model"] in models:
        selector_update = gr.Dropdown(value=url_params["model"], visible=True)

    return default_conversation.copy(), selector_update
112
+
113
+
114
def load_demo_refresh_model_list(request: gr.Request):
    """Build the initial state and a dropdown filled from ``get_model_list()``.

    Args:
        request: Gradio request, used only to log the client host.

    Returns:
        ``(state, dropdown_update)``; the dropdown preselects the first model,
        or the empty string when the list is empty.
    """
    logger.info(f"load_demo. ip: {request.client.host}")
    available = get_model_list()
    fresh_state = default_conversation.copy()
    preselected = available[0] if len(available) > 0 else ""
    selector_update = gr.Dropdown(choices=available, value=preselected)
    return fresh_state, selector_update
123
+
124
+
125
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one feedback record as a JSON line to today's conversation log.

    Args:
        state: Conversation whose ``dict()`` form is stored in the record.
        vote_type: Label stored under ``"type"``.
        model_selector: Model identifier stored under ``"model"``.
        request: Gradio request; the client host is stored under ``"ip"``.
    """
    log_path = get_conv_log_filename()
    with open(log_path, "a") as fout:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
            ip=request.client.host,
        )
        fout.write(json.dumps(record) + "\n")
135
+
136
+
137
def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an upvote; return an empty textbox value and three disabled buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
141
+
142
+
143
def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a downvote; return an empty textbox value and three disabled buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
147
+
148
+
149
def flag_last_response(state, model_selector, request: gr.Request):
    """Log a flag; return an empty textbox value and three disabled buttons."""
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
153
+
154
+
155
def regenerate(state, image_process_mode, request: gr.Request):
    """Prepare the conversation so the last answer can be generated again.

    The last message's content is reset to ``None``; when the preceding
    message carries a tuple/list payload, its third element is replaced by
    the current ``image_process_mode``.

    Returns:
        ``(state, chatbot, "", None)`` followed by five disabled buttons.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None

    human_turn = state.messages[-2]
    payload = human_turn[1]
    if type(payload) is tuple or type(payload) is list:
        human_turn[1] = tuple(payload[:2]) + (image_process_mode,)

    state.skip_next = False
    buttons = (disable_btn,) * 5
    return (state, state.to_gradio_chatbot(), "", None) + buttons
163
+
164
+
165
def clear_history(request: gr.Request):
    """Start over with a copy of the default conversation.

    Returns:
        ``(state, chatbot, "", None)`` followed by five disabled buttons.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = default_conversation.copy()
    buttons = (disable_btn,) * 5
    return (fresh_state, fresh_state.to_gradio_chatbot(), "", None) + buttons
169
+
170
+
171
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Turn the submitted policy text and image into a new conversation.

    Args:
        state: Current conversation; only used when the input is rejected.
        text: Content of the textbox.
        image: Uploaded image, or ``None``.
        image_process_mode: Stored alongside the image in the user message.
        request: Gradio request, used only to log the client host.

    Returns:
        ``(state, chatbot, textbox_value, None)`` plus five button updates.
        Rejected input (empty text, missing image, or moderated text when
        ``args.moderate`` is set) marks ``state.skip_next`` and leaves the
        buttons unchanged; accepted input yields a fresh conversation and
        disabled buttons.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")

    def _rejected(textbox_value):
        # Flag the state so the follow-up generation step is skipped.
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), textbox_value, None) + (no_change_btn,) * 5

    if len(text) <= 0 or image is None:
        return _rejected("")
    if args.moderate and violates_moderation(text):
        return _rejected(moderation_msg)

    # An image is guaranteed from here on, so the message always carries it.
    prompt_text = wrap_taxonomy(text)
    if '<image>' not in prompt_text:
        prompt_text = prompt_text + '\n<image>'
    user_message = (prompt_text, image, image_process_mode)

    conversation = clear_conv(default_conversation.copy())
    conversation.append_message(conversation.roles[0], user_message)
    conversation.append_message(conversation.roles[1], None)
    conversation.skip_next = False
    return (conversation, conversation.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
196
+
197
+
198
def _select_template_name(model_name):
    """Map a model name onto the key of the conversation template to use."""
    lowered = model_name.lower()
    # "plain" checkpoints that are not finetunes share the mmtag templates.
    uses_mmtag = 'mmtag' in lowered or ('plain' in lowered and 'finetune' not in lowered)
    if "llava" in lowered:
        if 'llama-2' in lowered:
            return "llava_llama_2"
        if "mistral" in lowered or "mixtral" in lowered:
            if 'orca' in lowered:
                return "mistral_orca"
            if 'hermes' in lowered:
                return "chatml_direct"
            return "mistral_instruct"
        if 'llava-v1.6-34b' in lowered:
            return "chatml_direct"
        if "v1" in lowered:
            return "v1_mmtag" if uses_mmtag else "llava_v1"
        if "mpt" in lowered:
            return "mpt"
        return "v0_mmtag" if uses_mmtag else "llava_v0"
    # Names without "llava" are matched case-sensitively, as before.
    if "mpt" in model_name:
        return "mpt_text"
    if "llama-2" in model_name:
        return "llama_2"
    return "vicuna_v1"


def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Run the model on the pending conversation and yield the UI update.

    This is a generator used as a Gradio event handler. It yields exactly one
    ``(state, chatbot, *five_button_updates)`` tuple, then (unless the call
    was skipped) appends a chat record to today's conversation log.

    Args:
        state: Conversation produced by ``add_text``/``regenerate``.
        model_selector: Selected model name; drives template selection and
            is written to the log.
        temperature, top_p, max_new_tokens: Accepted for interface
            compatibility; they are not forwarded to ``run_llava``.
        request: Gradio request; the client host is written to the log.
    """
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: rebuild the state on the template
        # that matches the selected model.
        new_state = conv_templates[_select_template_name(model_name)].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Construct prompt
    prompt = state.get_prompt()

    # Keep a copy of every served image, named by its MD5 digest.
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    today = datetime.datetime.now()
    day_dir = f"{today.year}-{today.month:02d}-{today.day:02d}"
    for image, image_hash in zip(all_images, all_image_hash):
        filename = os.path.join(LOGDIR, "serve_images", day_dir, f"{image_hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    output = run_llava(prompt, all_images[0])
    state.messages[-1][-1] = output

    # Generation has finished: re-enable the feedback/regenerate/clear
    # buttons, otherwise they would stay disabled for the whole session.
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
338
+
339
+
340
+ title_markdown = ("""
341
+ # LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
342
+ [[Project Page](https://ml-research.github.io/human-centered-genai/projects/llavaguard/index.html)]
343
+ [[Code](https://github.com/ml-research/LlavaGuard)]
344
+ [[Model](https://huggingface.co/collections/AIML-TUDA/llavaguard-665b42e89803408ee8ec1086)]
345
+ [[Dataset](https://huggingface.co/datasets/aiml-tuda/llavaguard)]
346
+ [[LavaGuard](https://arxiv.org/abs/2406.05113)]
347
+ """)
348
+
349
+ tos_markdown = ("""
350
+ ### Terms of use
351
+ By using this service, users are required to agree to the following terms:
352
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
353
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
354
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
355
+ """)
356
+
357
+ learn_more_markdown = ("""
358
+ ### License
359
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
360
+ """)
361
+
362
+ block_css = """
363
+
364
+ #buttons button {
365
+ min-width: min(120px,100%);
366
+ }
367
+
368
+ """
369
+
370
+ taxonomies = ["Default", "Modified w/ O1 non-violating", "Default message 3"]
371
+
372
+
373
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
    """Assemble the Gradio Blocks UI and wire up its event handlers.

    Reads the module-level ``models`` and ``args`` (``args.model_list_mode``).

    Args:
        embed_mode: When true, the title, terms-of-use and license markdown
            sections are left out.
        cur_dir: Directory that contains ``examples/image1.png`` ...
            ``image5.png``; defaults to this file's directory.
        concurrency_count: ``concurrency_limit`` applied to the
            ``llava_bot`` generation steps.

    Returns:
        The constructed ``gr.Blocks`` demo.

    Raises:
        ValueError: If ``args.model_list_mode`` is neither ``"once"`` nor
            ``"reload"``.
    """
    # The policy textbox is created up front and placed into the layout
    # later through ``textbox.render()``.
    with gr.Accordion("Safety Risk Taxonomy", open=False):
        textbox = gr.Textbox(
            label="Safety Risk Taxonomy",
            show_label=True,
            placeholder="Enter your safety policy here",
            container=True,
            value=default_taxonomy,
            lines=50)
    with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil", label="Image", container=False)
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6)
                ], inputs=imagebox)

                with gr.Accordion("Parameters", open=False):
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
                                            label="Temperature", )
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.1, interactive=True, label="Top P", )
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
                                                  label="Max output tokens", )

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="LLavaGuard Safety Assessment",
                    height=650,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons"):
                    # Labels use the intended emoji rather than mis-decoded bytes.
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        vote_inputs = [state, model_selector]
        vote_outputs = [textbox, upvote_btn, downvote_btn, flag_btn]
        upvote_btn.click(upvote_last_response, vote_inputs, vote_outputs)
        downvote_btn.click(downvote_last_response, vote_inputs, vote_outputs)
        flag_btn.click(flag_last_response, vote_inputs, vote_outputs)

        # Every generation step shares the same inputs and outputs.
        prepare_outputs = [state, chatbot, textbox, imagebox] + btn_list
        bot_inputs = [state, model_selector, temperature, top_p, max_output_tokens]
        bot_outputs = [state, chatbot] + btn_list

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            prepare_outputs
        ).then(
            llava_bot,
            bot_inputs,
            bot_outputs,
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            clear_history,
            None,
            prepare_outputs,
            queue=False
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            prepare_outputs,
            queue=False
        ).then(
            llava_bot,
            bot_inputs,
            bot_outputs,
            concurrency_limit=concurrency_count
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            prepare_outputs
        ).then(
            llava_bot,
            bot_inputs,
            bot_outputs,
            concurrency_limit=concurrency_count
        )

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                js=get_window_url_params
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
519
 
520
 
521
  if __name__ == "__main__":
 
528
  parser.add_argument("--share", action="store_true")
529
  parser.add_argument("--moderate", action="store_true")
530
  parser.add_argument("--embed", action="store_true")
531
+ args = parser.parse_args()
532
+ models = []
533
 
534
+ title_markdown += """
535
 
536
  ONLY WORKS WITH GPU!
537
 
 
541
  ['AIML-TUDA/LlavaGuard-34B'](https://huggingface.co/AIML-TUDA/LlavaGuard-34B),
542
  """
543
  # set_up_env_and_token(read=True)
544
+ print(f"args: {args}")
545
  # set the huggingface login token
546
+ # controller_proc = start_controller()
547
  concurrency_count = int(os.getenv("concurrency_count", 5))
548
  api_key = os.getenv("token")
549
  if api_key:
 
553
  if '/workspace' not in sys.path:
554
  sys.path.append('/workspace')
555
  from llavaguard.hf_utils import set_up_env_and_token
556
+
557
+ api_key = set_up_env_and_token(read=True, write=False)
558
 
559
  models = [
560
  'LukasHug/LlavaGuard-7B-hf',
561
  'LukasHug/LlavaGuard-13B-hf',
562
+ 'LukasHug/LlavaGuard-34B-hf', ]
563
  bits = int(os.getenv("bits", 16))
564
  model = os.getenv("model", models[0])
565
  available_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0")
566
  model_path, model_name = model, model.split("/")[1]
567
+ # model_path = '/common-repos/LlavaGuard/models/LlavaGuard-v1.1-7b-full/smid_and_crawled_v2_with_augmented_policies/json-v12/llava'
568
+
569
 
570
+ print(f"Loading model {model_path}")
571
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, token=api_key)
572
 
573
+ model.config.tokenizer_model_max_length = 2048 * 2
574
 
575
  # Wait for worker and controller to start
576
+ # time.sleep(10)
577
 
578
  exit_status = 0
579
  try:
580
+ demo = build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
581
  demo.queue(
582
  status_update_rate=10,
583
  api_open=False
584
  ).launch(
585
+ server_name=args.host,
586
+ server_port=args.port,
587
+ share=args.share
588
  )
589
 
590
  except Exception as e:
591
  print(e)
592
  exit_status = 1
593
  finally:
594
+ sys.exit(exit_status)