KingNish committed on
Commit
f46faa5
1 Parent(s): 0f6e2c0

Update app.py

Files changed (1):
  app.py +40 -15
app.py CHANGED
@@ -17,9 +17,11 @@ model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
 
 processor = LlavaProcessor.from_pretrained(model_id)
 
-model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True)
+model = LlavaForConditionalGeneration.from_pretrained(model_id)
 model.to("cpu")
 
+def replace_video_with_images(text, frames):
+    return text.replace("<video>", "<image>" * frames)
 
 def sample_frames(video_file):
     try:
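For reference, sample_frames (context in the hunk above; its body is unchanged by this commit and not shown) has to turn a video path into a list of PIL frames. A minimal sketch of such a sampler built on OpenCV; the body and the default frame count here are assumptions, not code from this diff:

import cv2
from PIL import Image

def sample_frames_sketch(video_file, num_frames=12):  # name and default are hypothetical
    video = cv2.VideoCapture(video_file)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    interval = max(total_frames // num_frames, 1)
    frames = []
    for i in range(total_frames):
        ret, frame = video.read()
        if not ret:
            continue
        if i % interval == 0:
            # OpenCV decodes to BGR; PIL and the Llava processor expect RGB
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    video.release()
    return frames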
@@ -88,26 +90,51 @@ def respond(message, history):
     vqa = ""
 
     user_prompt = message
+    message_text = message["text"]
+
     # Handle image processing
     if message["files"]:
-        image = user_prompt["files"][-1]
         txt = user_prompt["text"]
         img = user_prompt["files"]
+
+        if len(message["files"]) == 1:
+            image = [message["files"][0]]
+        elif len(message["files"]) > 1:
+            image = [msg for msg in message["files"]]
 
         video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
         image_extensions = Image.registered_extensions()
         image_extensions = tuple([ex for ex, f in image_extensions.items()])
-
-        if image.endswith(video_extensions):
-            gr.Info(f"Analyzing {video_extensions} file")
-            image = sample_frames(image)
-            image_tokens = "<image>" * int(len(image))
-            prompt = f"<|im_start|>user {image_tokens}\n{user_prompt}<|im_end|><|im_start|>assistant"
-
-        elif image.endswith(image_extensions):
-            gr.Info("Analyzing image")
-            image = Image.open(image).convert("RGB")
-            prompt = f"<|im_start|>user <image>\n{user_prompt}<|im_end|><|im_start|>assistant"
+
+        if len(image) == 1:
+            if image[0].endswith(video_extensions):
+                gr.Info("Analyzing video")
+                image = sample_frames(image[0])
+                image_tokens = "<image>" * int(len(image))
+                prompt = f"<|im_start|>user {image_tokens}\n{user_prompt}<|im_end|><|im_start|>assistant"
+            elif image[0].endswith(image_extensions):
+                gr.Info("Analyzing image")
+                image = Image.open(image[0]).convert("RGB")
+                prompt = f"<|im_start|>user <image>\n{user_prompt}<|im_end|><|im_start|>assistant"
+
+        elif len(image) > 1:
+            image_list = []
+
+            for img in image:
+                if img.endswith(image_extensions):
+                    gr.Info("Analyzing image")
+                    img = Image.open(img).convert("RGB")
+                    image_list.append(img)
+                elif img.endswith(video_extensions):
+                    gr.Info("Analyzing video")
+                    frames = sample_frames(img)
+                    for frame in frames:
+                        image_list.append(frame)
+
+            image_tokens = "<image>" * len(image_list)
+            prompt = f"<|im_start|>user {image_tokens}\n{user_prompt}<|im_end|><|im_start|>assistant"
+            image = image_list
 
     inputs = processor(prompt, image, return_tensors="pt")
     streamer = TextIteratorStreamer(processor, skip_prompt=True, **{"skip_special_tokens": True})
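The branching added above preserves one invariant: the prompt must contain exactly one <image> token per image (or sampled frame) handed to the processor. A hypothetical walkthrough with stand-in values:

from PIL import Image

# Stand-ins for two uploaded photos plus three frames sampled from a video.
image_list = [Image.new("RGB", (336, 336), "gray") for _ in range(5)]

image_tokens = "<image>" * len(image_list)  # five <image> tokens for five images
user_text = "What do these have in common?"  # hypothetical user message
prompt = f"<|im_start|>user {image_tokens}\n{user_text}<|im_end|><|im_start|>assistant"

# processor is the LlavaProcessor created at the top of app.py
inputs = processor(prompt, image_list, return_tensors="pt")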
@@ -116,7 +143,6 @@ def respond(message, history):
 
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-    gr.Info("Generating output")
 
     buffer = ""
     for new_text in streamer:
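For context, the unchanged lines around this hunk use the standard transformers streaming pattern: model.generate() runs in a background thread while TextIteratorStreamer yields decoded text incrementally. A generic sketch, with max_new_tokens as an assumed value:

from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=512)  # 512 is an assumption

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

buffer = ""
for new_text in streamer:  # blocks until the generate() thread produces text
    buffer += new_text
    # app.py yields each partial buffer back to the Gradio chat UI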
@@ -132,7 +158,6 @@ def respond(message, history):
         {"type": "function", "function": {"name": "image_qna", "description": "Answer question asked by user related to image", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Question by user"}}, "required": ["query"]}}},
     ]
 
-    message_text = message["text"]
     func_caller.append({"role": "user", "content": f'[SYSTEM]You are a helpful assistant. You have access to the following functions: \n {str(functions_metadata)}\n\nTo use these functions respond with:\n<functioncall> {{ "name": "function_name", "arguments": {{ "arg_1": "value_1", "arg_1": "value_1", ... }} }} </functioncall> [USER] {message} {vqa}'})
 
     response = client_gemma.chat_completion(func_caller, max_tokens=150)
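The [SYSTEM] prompt above instructs the model to answer with a <functioncall> ... </functioncall> wrapper, so downstream code has to recover the JSON payload. A hedged sketch of such a parser; the helper name and regex are illustrative, not code from this commit:

import json
import re

def parse_function_call(response_text):  # hypothetical helper, not in app.py
    match = re.search(r"<functioncall>\s*(\{.*\})\s*</functioncall>", response_text, re.DOTALL)
    if match is None:
        return None
    try:
        return json.loads(match.group(1))
    except json.JSONDecodeError:
        return None

reply = '<functioncall> {"name": "image_qna", "arguments": {"query": "What is shown?"}} </functioncall>'
parse_function_call(reply)
# -> {"name": "image_qna", "arguments": {"query": "What is shown?"}}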
 