Ivy1997 commited on
Commit
4aa36bf
·
verified ·
1 Parent(s): 7acf621

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -67
app.py CHANGED
@@ -1,10 +1,4 @@
1
  import gradio as gr
2
- import subprocess
3
- # subprocess.run(
4
- # "pip install flash-attn --no-build-isolation",
5
- # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
6
- # shell=True,
7
- # )
8
  from llava.model.builder import load_pretrained_model
9
  from llava.mm_utils import process_images, tokenizer_image_token
10
  from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
@@ -13,7 +7,6 @@ from PIL import Image
13
  import copy
14
  import torch
15
  import warnings
16
- import requests
17
 
18
  warnings.filterwarnings("ignore")
19
 
@@ -26,41 +19,26 @@ device_map = "auto"
26
  tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
27
  model.eval()
28
 
29
- def respond(
30
- message,
31
- history: list[tuple[str, str]],
32
- system_message,
33
- max_tokens,
34
- temperature,
35
- top_p,
36
- image=None,
37
- ):
38
- messages = [{"role": "system", "content": system_message}]
39
-
40
- for val in history:
41
- if val[0]:
42
- messages.append({"role": "user", "content": val[0]})
43
- if val[1]:
44
- messages.append({"role": "assistant", "content": val[1]})
45
-
46
- if image:
47
  # Load and process the image
48
- if isinstance(image, str):
49
- image = Image.open(image)
50
-
51
  image_tensor = process_images([image], image_processor, model.config)
52
  image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]
53
 
 
54
  conv_template = "qwen_1_5"
55
- question = DEFAULT_IMAGE_TOKEN + "\n" + message
56
  conv = copy.deepcopy(conv_templates[conv_template])
57
- conv.append_message(conv.roles[0], question)
58
  conv.append_message(conv.roles[1], None)
59
  prompt_question = conv.get_prompt()
60
 
 
61
  input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
62
  image_sizes = [image.size]
63
 
 
64
  cont = model.generate(
65
  input_ids,
66
  images=image_tensor,
@@ -70,45 +48,28 @@ def respond(
70
  max_new_tokens=max_tokens,
71
  )
72
 
73
- response = tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
74
- else:
75
- messages.append({"role": "user", "content": message})
76
-
77
- conv_template = "qwen_1_5"
78
- conv = copy.deepcopy(conv_templates[conv_template])
79
- conv.append_message(conv.roles[0], message)
80
- conv.append_message(conv.roles[1], None)
81
- prompt_question = conv.get_prompt()
82
-
83
- input_ids = tokenizer(prompt_question, return_tensors="pt", max_length=max_tokens, truncation=True).to(device)
84
-
85
- cont = model.generate(
86
- input_ids,
87
- do_sample=True,
88
- temperature=temperature,
89
- max_new_tokens=max_tokens,
90
- top_p=top_p,
91
- )
92
-
93
- response = tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
94
-
95
- yield response
96
-
97
- demo = gr.ChatInterface(
98
- respond,
99
- additional_inputs=[
100
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
101
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
102
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
103
- gr.Slider(
104
- minimum=0.1,
105
- maximum=1.0,
106
- value=0.95,
107
- step=0.05,
108
- label="Top-p (nucleus sampling)",
109
- ),
110
- gr.Image(type="filepath", label="Input Image (optional)"),
111
  ],
 
 
 
112
  )
113
 
114
  if __name__ == "__main__":
 
1
  import gradio as gr
 
 
 
 
 
 
2
  from llava.model.builder import load_pretrained_model
3
  from llava.mm_utils import process_images, tokenizer_image_token
4
  from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
 
7
  import copy
8
  import torch
9
  import warnings
 
10
 
11
  warnings.filterwarnings("ignore")
12
 
 
19
  tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
20
  model.eval()
21
 
22
+ def respond(image_path, question, temperature, max_tokens):
23
+ try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Load and process the image
25
+ image = Image.open(image_path)
 
 
26
  image_tensor = process_images([image], image_processor, model.config)
27
  image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]
28
 
29
+ # Prepare the conversation template
30
  conv_template = "qwen_1_5"
31
+ formatted_question = DEFAULT_IMAGE_TOKEN + "\n" + question
32
  conv = copy.deepcopy(conv_templates[conv_template])
33
+ conv.append_message(conv.roles[0], formatted_question)
34
  conv.append_message(conv.roles[1], None)
35
  prompt_question = conv.get_prompt()
36
 
37
+ # Tokenize input
38
  input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
39
  image_sizes = [image.size]
40
 
41
+ # Generate response
42
  cont = model.generate(
43
  input_ids,
44
  images=image_tensor,
 
48
  max_new_tokens=max_tokens,
49
  )
50
 
51
+ text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
52
+ return text_outputs[0]
53
+ except Exception as e:
54
+ return f"Error: {str(e)}"
55
+
56
+ # Gradio Interface
57
+ def chat_interface(image, question, temperature, max_tokens):
58
+ if not image or not question:
59
+ return "Please provide both an image and a question."
60
+ return respond(image.name, question, temperature, max_tokens)
61
+
62
+ demo = gr.Interface(
63
+ fn=chat_interface,
64
+ inputs=[
65
+ gr.Image(type="file", label="Input Image"),
66
+ gr.Textbox(label="Question"),
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
68
+ gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max Tokens"),
 
 
 
 
 
 
 
69
  ],
70
+ outputs="text",
71
+ title="AI-Safeguard Ivy-VL-Llava Image Question Answering",
72
+ description="Upload an image and ask a question about it. The model will provide a response based on the visual and textual input."
73
  )
74
 
75
  if __name__ == "__main__":