import dataclasses
import os
import time
from enum import auto, Enum
from pathlib import Path
from typing import Any, List

import gradio as gr
import lancedb
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from moviepy.video.io.VideoFileClip import VideoFileClip

from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
from mm_rag.MLM.client import PredictionGuardClient
from mm_rag.MLM.lvlm import LVLM
from utility import (
    encode_image,
    lvlm_inference_with_conversation,
    prediction_guard_llava_conv,
)

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"


def split_video(
    video_path,
    timestamp_in_ms,
    output_video_path: str = "./shared_data/splitted_videos",
    output_video_name: str = "video_tmp.mp4",
    play_before_sec: int = 3,
    play_after_sec: int = 3,
):
    """Cut a short clip around `timestamp_in_ms` and return the path to the new file."""
    timestamp_in_sec = int(timestamp_in_ms / 1000)
    # Create the output folder if it does not exist yet.
    Path(output_video_path).mkdir(parents=True, exist_ok=True)
    output_video = os.path.join(output_video_path, output_video_name)
    with VideoFileClip(video_path) as video:
        duration = video.duration
        # Clamp the clip window to the bounds of the source video.
        start_time = max(timestamp_in_sec - play_before_sec, 0)
        end_time = min(timestamp_in_sec + play_after_sec, duration)
        new = video.subclip(start_time, end_time)
        new.write_videofile(output_video, audio_codec='aac')
    return output_video
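
# A minimal usage sketch (the input path and the 7000 ms timestamp below are
# illustrative, not part of the app):
#
#   clip_path = split_video("./shared_data/videos/input.mp4", timestamp_in_ms=7000)
#   print(clip_path)  # ./shared_data/splitted_videos/video_tmp.mp4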


prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""
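
# For example (with an illustrative transcript),
# prompt_template.format(transcript="one small step for man",
#                        user_query="What does the astronaut say?")
# produces:
#   "The transcript associated with the image is 'one small step for man'.
#    What does the astronaut say?"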


def get_default_rag_chain():
    """Build the default multimodal RAG chain: retrieve one video segment, then answer with the LVLM."""
    LANCEDB_HOST_FILE = "./shared_data/.lancedb"
    TBL_NAME = "demo_tbl"

    # Connect to the local LanceDB instance and wrap it in a multimodal vector store.
    db = lancedb.connect(LANCEDB_HOST_FILE)
    embedder = BridgeTowerEmbeddings()
    vectorstore = MultimodalLanceDB(uri=LANCEDB_HOST_FILE, embedding=embedder, table_name=TBL_NAME)
    retriever_module = vectorstore.as_retriever(search_type='similarity', search_kwargs={"k": 1})

    # LVLM inference backed by Prediction Guard.
    client = PredictionGuardClient()
    lvlm_inference_module = LVLM(client=client)

    def prompt_processing(input):
        # Keep only the top retrieved video segment and turn it into an LVLM input.
        retrieved_results, user_query = input['retrieved_results'], input['user_query']
        retrieved_result = retrieved_results[0]
        metadata_retrieved_video_segment = retrieved_result.metadata['metadata']
        transcript = metadata_retrieved_video_segment['transcript']
        frame_path = metadata_retrieved_video_segment['extracted_frame_path']
        return {
            'prompt': prompt_template.format(transcript=transcript, user_query=user_query),
            'image': frame_path,
            'metadata': metadata_retrieved_video_segment,
        }

    prompt_processing_module = RunnableLambda(prompt_processing)

    # Retrieval -> prompt construction -> LVLM inference, keeping the
    # intermediate input around so the caller can recover the retrieved frame.
    mm_rag_chain_with_retrieved_image = (
        RunnableParallel({"retrieved_results": retriever_module,
                          "user_query": RunnablePassthrough()})
        | prompt_processing_module
        | RunnableParallel({'final_text_output': lvlm_inference_module,
                            'input_to_lvlm': RunnablePassthrough()})
    )
    return mm_rag_chain_with_retrieved_image
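
# A minimal invocation sketch (assumption: the "demo_tbl" table in
# ./shared_data/.lancedb has already been populated with embedded video
# segments by a separate ingestion step):
#
#   chain = get_default_rag_chain()
#   response = chain.invoke("What does the astronaut say?")
#   print(response["final_text_output"])          # the LVLM's answer
#   print(response["input_to_lvlm"]["metadata"])  # metadata of the retrieved segment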


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()


@dataclasses.dataclass
class GradioInstance:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "\n"
    sep2: str = None
    version: str = "Unknown"
    path_to_img: str = None
    video_title: str = None
    path_to_video: str = None
    caption: str = None
    mm_rag_chain: Any = None
    skip_next: bool = False

    def _template_caption(self):
        out = ""
        if self.caption is not None:
            out = f"The caption associated with the image is '{self.caption}'. "
        return out

    def get_prompt_for_rag(self):
        messages = self.messages
        assert len(messages) == 2, "length of current conversation should be 2"
        assert messages[1][1] is None, "the first response message of current conversation should be None"
        ret = messages[0][1]
        return ret

    def get_conversation_for_lvlm(self):
        # Rebuild the running conversation in Prediction Guard's LLaVA format:
        # the first user turn carries the base64-encoded retrieved frame, and
        # the latest user turn is augmented with the cached transcript via
        # prompt_template.
        pg_conv = prediction_guard_llava_conv.copy()
        image_path = self.path_to_img
        b64_img = encode_image(image_path)
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if msg is None:
                break
            if i == 0:
                pg_conv.append_message(prediction_guard_llava_conv.roles[0], [msg, b64_img])
            elif i == len(self.messages[self.offset:]) - 2:
                pg_conv.append_message(role, [prompt_template.format(transcript=self.caption, user_query=msg)])
            else:
                pg_conv.append_message(role, [msg])
        return pg_conv
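
    # Sketch of the message layout this method expects after two turns (role
    # strings actually come from prediction_guard_llava_conv.roles; the text
    # is illustrative):
    #   [["user", "What do you see?"], ["assistant", "An astronaut."],
    #    ["user", "What is he holding?"], ["assistant", None]]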

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        # The app tracks at most one retrieved frame, returned as a one-item list.
        images = []
        if self.path_to_img is not None:
            path_to_image = self.path_to_img
            images.append(path_to_image)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    msg, image, image_process_mode = msg
                    # Resize so the shorter edge is at most 400 px and the
                    # longer edge at most 800 px, preserving aspect ratio.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    # The buffer holds JPEG data, so declare a JPEG data URI.
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return GradioInstance(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
            mm_rag_chain=self.mm_rag_chain,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "path_to_img": self.path_to_img,
            "video_title": self.video_title,
            "path_to_video": self.path_to_video,
            "caption": self.caption,
        }

    def get_path_to_subvideos(self):
        # Prefer deriving the sub-video path from the retrieved frame's index;
        # fall back to an explicitly set video path. Note: video_helper_map is
        # not defined in this module and must be provided by the surrounding
        # application before this branch is reached.
        if self.video_title is not None and self.path_to_img is not None:
            info = video_helper_map[self.video_title]
            path = info['path']
            prefix = info['prefix']
            vid_index = self.path_to_img.split('/')[-1]
            vid_index = vid_index.split('_')[-1]
            vid_index = vid_index.replace('.jpg', '')
            ret = f"{prefix}{vid_index}.mp4"
            ret = os.path.join(path, ret)
            return ret
        elif self.path_to_video is not None:
            return self.path_to_video
        return None
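
    # Illustrative resolution (all names hypothetical): with path_to_img
    # ".../extracted_frame/frame_12.jpg" and info = {"path": "./videos",
    # "prefix": "clip_"}, this returns "./videos/clip_12.mp4".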


def get_gradio_instance(mm_rag_chain=None):
    if mm_rag_chain is None:
        mm_rag_chain = get_default_rag_chain()

    instance = GradioInstance(
        system="",
        roles=prediction_guard_llava_conv.roles,
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.SINGLE,
        sep="\n",
        path_to_img=None,
        video_title=None,
        caption=None,
        mm_rag_chain=mm_rag_chain,
    )
    return instance


gr.set_static_paths(paths=["./assets/"])
theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        c50="#eff6ff", c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd",
        c400="#60a5fa", c500="#0054ae", c600="#00377c", c700="#00377c",
        c800="#1e40af", c900="#1e3a8a", c950="#0a0c2b"),
    secondary_hue=gr.themes.Color(
        c50="#eff6ff", c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd",
        c400="#60a5fa", c500="#0054ae", c600="#0054ae", c700="#0054ae",
        c800="#1e40af", c900="#1e3a8a", c950="#1d3660"),
).set(
    body_background_fill_dark='*primary_950',
    body_text_color_dark='*neutral_300',
    border_color_accent='*primary_700',
    border_color_accent_dark='*neutral_800',
    block_background_fill_dark='*primary_950',
    block_border_width='2px',
    block_border_width_dark='2px',
    button_primary_background_fill_dark='*primary_500',
    button_primary_border_color_dark='*primary_500',
)

css = '''
@font-face {
    font-family: IntelOne;
    src: url("/file=./assets/intelone-bodytext-font-family-regular.ttf");
}
.gradio-container {background-color: #0a0c2b}
table {
    border-collapse: collapse;
    border: none;
}
'''


html_title = '''
<table style="border: none; background-color: #0a0c2b;">
  <tr style="height: 150px; border: none;">
    <td style="border: none;"><img src="/file=./assets/header.png"></td>
  </tr>
</table>
'''


dropdown_list = [
    "What is the name of one of the astronauts?",
    "An astronaut's spacewalk",
    "What does the astronaut say?",
]

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)


def clear_history(state, request: gr.Request):
    state = get_gradio_instance(state.mm_rag_chain)
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,)


def add_text(state, text, request: gr.Request):
    # Outputs are (state, chatbot, textbox, clear_btn); both branches must
    # return a tuple of that length.
    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,)

    text = text[:1536]  # hard cap on query length

    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,)


def http_bot(state, request: gr.Request):
    start_tstamp = time.time()

    if state.skip_next:
        # The submit was ignored (empty query); re-emit the current state.
        path_to_sub_videos = state.get_path_to_subvideos()
        yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,)
        return

    if len(state.messages) == state.offset + 2:
        # First query of this conversation: start from a fresh instance.
        new_state = get_gradio_instance(state.mm_rag_chain)
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    all_images = state.get_images(return_pil=False)

    # With no retrieved frame yet, run the full RAG chain; otherwise continue
    # the conversation against the LVLM directly.
    is_very_first_query = len(all_images) == 0
    if is_very_first_query:
        prompt_or_conversation = state.get_prompt_for_rag()
        executor = state.mm_rag_chain
    else:
        prompt_or_conversation = state.get_conversation_for_lvlm()
        executor = lvlm_inference_with_conversation

    # Show a typing cursor while the answer is being generated.
    state.messages[-1][-1] = "▌"
    path_to_sub_videos = state.get_path_to_subvideos()
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,)

    try:
        if is_very_first_query:
            response = executor.invoke(prompt_or_conversation)
            message = response['final_text_output']
            if 'metadata' in response['input_to_lvlm']:
                metadata = response['input_to_lvlm']['metadata']
                if state.path_to_img is None and 'image' in response['input_to_lvlm']:
                    state.path_to_img = response['input_to_lvlm']['image']

                if state.path_to_video is None and 'video_path' in metadata:
                    # Cut a short clip around the retrieved segment's midpoint.
                    video_path = metadata['video_path']
                    mid_time_ms = metadata['mid_time_ms']
                    splited_video_path = split_video(video_path, mid_time_ms)
                    state.path_to_video = splited_video_path

                if state.caption is None and 'transcript' in metadata:
                    state.caption = metadata['transcript']
            else:
                raise ValueError("Response's format is changed")
        else:
            message = executor(prompt_or_conversation)

    except Exception as e:
        print(e)
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None) + (enable_btn,)
        return

    state.messages[-1][-1] = message
    path_to_sub_videos = state.get_path_to_subvideos()
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,)

    finish_tstamp = time.time()
    return


def get_demo(rag_chain=None):
    if rag_chain is None:
        rag_chain = get_default_rag_chain()

    with gr.Blocks(theme=theme, css=css) as demo:
        instance = get_gradio_instance(rag_chain)
        state = gr.State(instance)
        # Force the dark theme via a query parameter on first load.
        demo.load(
            None,
            None,
            js="""
            () => {
                const params = new URLSearchParams(window.location.search);
                if (!params.has('__theme')) {
                    params.set('__theme', 'dark');
                    window.location.search = params.toString();
                }
            }""",
        )
        gr.HTML(value=html_title)
        with gr.Row():
            with gr.Column(scale=4):
                video = gr.Video(height=512, width=512, elem_id="video", interactive=False)
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    elem_id="chatbot", label="Multimodal RAG Chatbot", height=512,
                )
        with gr.Row():
            with gr.Column(scale=8):
                textbox = gr.Dropdown(
                    dropdown_list,
                    allow_custom_value=True,
                    label="Query",
                    info="Enter your query here or choose a sample from the dropdown list!",
                )
            with gr.Column(scale=1, min_width=50):
                submit_btn = gr.Button(
                    value="Send", variant="primary", interactive=True
                )
        with gr.Row(elem_id="buttons") as button_row:
            clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

        btn_list = [clear_btn]

        clear_btn.click(
            clear_history, [state], [state, chatbot, textbox, video] + btn_list
        )
        submit_btn.click(
            add_text,
            [state, textbox],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [state],
            [state, chatbot, video] + btn_list,
        )
    return demo
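

# Minimal entry-point sketch (assumption: this module may be run directly;
# the queue() call and default launch settings are illustrative, not taken
# from the original deployment):
if __name__ == "__main__":
    demo = get_demo()
    demo.queue()   # generator-based handlers like http_bot stream through the queue
    demo.launch()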