arjunanand13 committed on
Commit
16cdbba
1 Parent(s): 2af2693

Create app.py

Files changed (1)
  1. app.py +552 -0
app.py ADDED
@@ -0,0 +1,552 @@
+ # Usage
+ ### python main.py --mode interface
+ ### python main.py videos/Spirituality_1_clip.mp4 -n 3 --mode inference --model gemini
+ import gradio as gr
+ import os
+ import whisper
+ import cv2
+ import json
+ import tempfile
+ import torch
+ import transformers
+ from transformers import pipeline
+ import re
+ import time
+ from torch import cuda, bfloat16
+ from moviepy.editor import VideoFileClip
+ from image_caption import Caption
+ from pathlib import Path
+ from langchain import PromptTemplate
+ from langchain import LLMChain
+ from langchain.llms import HuggingFacePipeline
+ from difflib import SequenceMatcher
+ import argparse
+ import shutil
+ from PIL import Image
+ import google.generativeai as genai
+ from huggingface_hub import InferenceClient
+ from openai import OpenAI
+
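+ # Pipeline overview: extract and transcribe the audio track (Whisper), optionally
+ # caption a few sampled frames (Gemini image captioning), then prompt the selected
+ # LLM (Gemini or Mistral) twice to pick an IAB Tier 1 main class and Tier 2 sub-class.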
+ class VideoClassifier:
+     global audio_time, setup_time, caption_time, classification_time
+     audio_time = 0
+     setup_time = 0
+     caption_time = 0
+     classification_time = 0
+
+     def __init__(self, no_of_frames, mode='interface', model='gemini'):
+         self.no_of_frames = no_of_frames
+         self.mode = mode
+         self.model_name = model.strip().lower()
+         print(self.model_name)
+         os.environ["TOKENIZERS_PARALLELISM"] = "false"
+         # The Mistral setup below reads self.hf_key, so resolve it before branching.
+         self.hf_key = os.environ.get("HF_KEY", None)
+         if self.model_name == 'mistral':
+             print("Setting up Mistral model for Class Selection")
+             self.setup_mistral_model()
+         else:
+             print("Setting up Gemini model for Class Selection")
+             self.setup_gemini_model()
+         self.setup_paths()
+         # OpenAI ("chatgpt 3.5") client, used for whisper-1 transcription; the API key
+         # is read from the environment.
+         self.chatgpt_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+         # self.whisper_model = whisper.load_model("base")
+
+     def setup_paths(self):
+         self.path = './results'
+         if os.path.exists(self.path):
+             shutil.rmtree(self.path)
+         os.mkdir(self.path)
+
+     def setup_gemini_model(self):
+         self.genai = genai
+         # The Gemini API key is read from the environment.
+         self.genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
+         self.genai_model = genai.GenerativeModel('gemini-pro')
+         self.whisper_model = whisper.load_model("base")
+         self.img_cap = Caption()
+
+     def setup_mistral_space_model(self):
+         # if not self.hf_key:
+         #     raise ValueError("Hugging Face API key is not set or invalid.")
+         self.client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
+         # self.client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
+         # self.client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
+         self.whisper_model = whisper.load_model("base")
+         self.img_cap = Caption()
+
+     def setup_mistral_model(self):
+         self.model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+         self.device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
+         # self.device_name = torch.cuda.get_device_name()
+         # print(f"Using device: {self.device} ({self.device_name})")
+         bnb_config = transformers.BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type='nf4',
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_compute_dtype=bfloat16,
+         )
+         hf_auth = self.hf_key
+         print(hf_auth)
+         model_config = transformers.AutoConfig.from_pretrained(
+             self.model_id,
+             # use_auth_token=hf_auth
+         )
+         self.model = transformers.AutoModelForCausalLM.from_pretrained(
+             self.model_id,
+             trust_remote_code=True,
+             config=model_config,
+             quantization_config=bnb_config,
+             # use_auth_token=hf_auth
+         )
+         self.model.eval()
+         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+             self.model_id,
+             # use_auth_token=hf_auth
+         )
+         self.generate_text = transformers.pipeline(
+             model=self.model, tokenizer=self.tokenizer,
+             return_full_text=True,
+             task='text-generation',
+             temperature=0.01,
+             max_new_tokens=32
+         )
+         self.whisper_model = whisper.load_model("base")
+         self.img_cap = Caption()
+         self.llm = HuggingFacePipeline(pipeline=self.generate_text)
+
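+     # Three interchangeable transcription paths follow: local Whisper ("base"),
+     # the whisper-large-v3 pipeline pulled from Hugging Face (used on Spaces),
+     # and the OpenAI whisper-1 API; classify_video tries the local path first
+     # and falls back to the Space pipeline.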
+     def audio_extraction(self, video_input):
+         """When running locally we use this library approach, which takes about 3 seconds of GPU inference."""
+         global audio_time
+         start_time_audio = time.time()
+         print(f"Processing video: {video_input} with {self.no_of_frames} frames.")
+         mp4_file = video_input
+         video_name = mp4_file.split("/")[-1]
+         wav_file = "results/audiotrack.wav"
+         video_clip = VideoFileClip(mp4_file)
+         audioclip = video_clip.audio
+         wav_file = audioclip.write_audiofile(wav_file)
+         audioclip.close()
+         video_clip.close()
+         audiotrack = "results/audiotrack.wav"
+         result = self.whisper_model.transcribe(audiotrack, fp16=False)
+         transcript = result["text"]
+         print("TRANSCRIPT", transcript)
+         end_time_audio = time.time()
+         audio_time = end_time_audio - start_time_audio
+         # print("TIME TAKEN FOR AUDIO CONVERSION (WHISPER)", audio_time)
+         return transcript
+
+     def audio_extraction_space(self, video_input):
+         """When running the project in a Space we use the model directly from Hugging Face to reduce inference time."""
+         MODEL_NAME = "openai/whisper-large-v3"
+         BATCH_SIZE = 8
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         global audio_time
+         start_time_audio = time.time()
+         print(f"Processing video: {video_input} with {self.no_of_frames} frames.")
+         mp4_file = video_input
+         video_name = mp4_file.split("/")[-1]
+         wav_file = "results/audiotrack.wav"
+         video_clip = VideoFileClip(mp4_file)
+         audioclip = video_clip.audio
+         wav_file = audioclip.write_audiofile(wav_file)
+         audioclip.close()
+         video_clip.close()
+         audiotrack = "results/audiotrack.wav"
+         pipe = pipeline(
+             "automatic-speech-recognition",
+             model=MODEL_NAME,
+             device=device
+         )
+         # if audio_file is None:
+         #     return "No audio file submitted! Please upload or record an audio file before submitting your request."
+         # if not os.path.exists(audio_file):
+         #     return "File does not exist. Please check the file path."
+         task = "transcribe"
+         result = pipe(audiotrack, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)
+         audio_time = time.time() - start_time_audio
+         return result["text"]
+
+     def audio_extraction_chatgptapi(self, video_input):
+         """For CPU-only environments we use this function, which transcribes via the OpenAI API for faster inference."""
+         global audio_time
+         start_time_audio = time.time()
+         print(f"Processing video: {video_input} with {self.no_of_frames} frames.")
+         mp4_file = video_input
+         video_name = mp4_file.split("/")[-1]
+         wav_file = "results/audiotrack.wav"
+         video_clip = VideoFileClip(mp4_file)
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
+             video_clip.audio.write_audiofile(temp_audio.name, codec='pcm_s16le', nbytes=2, fps=16000)
+             video_clip.close()
+
+         with open(temp_audio.name, 'rb') as audio_file:
+             transcription = self.chatgpt_client.audio.transcriptions.create(
+                 model="whisper-1",
+                 file=audio_file
+             )
+         print(transcription.text)
+         os.remove(temp_audio.name)
+         # audioclip = video_clip.audio
+         # wav_file = audioclip.write_audiofile(wav_file)
+         # audioclip.close()
+         # video_clip.close()
+         # audiotrack = "results/audiotrack.wav"
+         # transcription = self.client.audio.transcriptions.create(
+         #     model="whisper-1",
+         #     file=audioclip
+         # )
+         # print(transcription.text)
+         audio_time = time.time() - start_time_audio
+         return transcription.text
+
+     def generate_text(self, inputs, parameters=None):
+         # Note: when the local Mistral pipeline is loaded, setup_mistral_model() stores a
+         # transformers pipeline on self.generate_text, which shadows this method.
+         if parameters is None:
+             parameters = {
+                 "temperature": 0.7,
+                 "max_new_tokens": 50,
+                 "top_p": 0.9,
+                 "repetition_penalty": 1.2
+             }
+         return self.client(inputs, parameters)
+
+     default_checkbox = []
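+     # classify_video drives the full pipeline: transcribe the audio track, optionally
+     # caption sampled frames, then run the two-stage prompts (Tier 1 main class,
+     # Tier 2 sub-class) against the selected model and post-process the answers.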
+     def classify_video(self, video_input, checkbox=default_checkbox):
+         global classification_time, caption_time
+         print("checkbox", checkbox)
+         # transcript = self.audio_extraction_space(video_input)
+         try:
+             transcript = self.audio_extraction(video_input)
+         except Exception:
+             transcript = self.audio_extraction_space(video_input)
+         # try:
+         #     transcript = self.audio_extraction_chatgptapi(video_input)
+         # except Exception:
+         #     print("Chatgpt key expired, inferencing using whisper library")
+         #     try:
+         #         transcript = self.audio_extraction(video_input)
+         #     except Exception:
+         #         transcript = self.audio_extraction_space(video_input)
+         start_time_caption = time.time()
+         captions = ""
+
+         if checkbox == ["Image Captions and Audio for Classification"]:
+             video = cv2.VideoCapture(video_input)
+             length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+             no_of_frame = int(self.no_of_frames)
+             temp_div = length // no_of_frame
+             currentframe = 50
+             caption_text = []
+
+             for i in range(no_of_frame):
+                 video.set(cv2.CAP_PROP_POS_FRAMES, currentframe)
+                 ret, frame = video.read()
+                 if ret:
+                     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                     image = Image.fromarray(frame)
+                     content = self.img_cap.predict_image_caption_gemini(image)
+                     print("content", content)
+                     caption_text.append(content)
+                     currentframe += temp_div - 1
+                 else:
+                     break
+
+             captions = ", ".join(caption_text)
+             print("CAPTIONS", captions)
+             video.release()
+             cv2.destroyAllWindows()
+
+         # print("TIME TAKEN FOR IMAGE CAPTIONING", end_time_caption - start_time_caption)
+         end_time_caption = time.time()
+         caption_time = end_time_caption - start_time_caption
+         start_time_generation = time.time()
+         main_categories = Path("main_classes.txt").read_text()
+         main_categories_list = ['Automotive', 'Books and Literature', 'Business and Finance', 'Careers', 'Education', 'Family and Relationships',
+                                 'Fine Art', 'Food & Drink', 'Healthy Living', 'Hobbies & Interests', 'Home & Garden', 'Medical Health', 'Movies', 'Music and Audio',
+                                 'News and Politics', 'Personal Finance', 'Pets', 'Pop Culture', 'Real Estate', 'Religion & Spirituality', 'Science', 'Shopping', 'Sports',
+                                 'Style & Fashion', 'Technology & Computing', 'Television', 'Travel', 'Video Gaming']
+
+         generate_kwargs = {
+             "temperature": 0.9,
+             "max_new_tokens": 256,
+             "top_p": 0.95,
+             "repetition_penalty": 1.0,
+             "do_sample": True,
+             "seed": 42,
+             "return_full_text": False
+         }
+
+         template1 = '''Given below are the different types of main video classes
+         {main_categories}
+         You are a text classifier that categorises the transcript and captions into the one main class whose context they match. Generate only the main class name, with no sub-class or explanation.
+         Give more importance to the Transcript while classifying.
+         Transcript: {transcript}
+         Captions: {captions}
+         Return only the answer chosen from the list and nothing else.
+         Main-class => '''
+
+         prompt1 = PromptTemplate(template=template1, input_variables=['main_categories', 'transcript', 'captions'])
+         print("PROMPT 1", prompt1)
+         # print(self.model)
+         # print(f"Current model in use: {self.model}")
+         if self.model_name == 'mistral':
+             try:
+                 print("Entering mistral chain approach")
+                 chain1 = LLMChain(llm=self.llm, prompt=prompt1)
+                 main_class = chain1.predict(main_categories=main_categories, transcript=transcript, captions=captions)
+             except Exception:
+                 print("Entering mistral template approach")
+                 prompt1 = template1.format(main_categories=main_categories, transcript=transcript, captions=captions)
+                 messages = [{"role": "user", "content": prompt1}]
+                 stream = self.client.chat_completion(messages, max_tokens=100)
+                 main_class = stream.choices[0].message.content.strip()
+                 # output = ""
+                 # for response in stream:
+                 #     output += response['token'].text
+                 # print("Streaming output:", output)
+                 # main_class = output.strip()
+
+             print(main_class)
+             print("#######################################################")
+             try:
+                 pattern = r"Main-class =>\s*(.+)"
+                 match = re.search(pattern, main_class)
+                 if match:
+                     main_class = match.group(1).strip()
+             except Exception:
+                 pass
+         else:
+             prompt_text = template1.format(main_categories=main_categories, transcript=transcript, captions=captions)
+             response = self.genai_model.generate_content(contents=prompt_text)
+             main_class = response.text
+
+         print(main_class)
+         print("#######################################################")
+         print("MAIN CLASS: ", main_class)
+         def category_class(class_name, categories_list):
+             # Map a free-form model answer onto the closest entry in the category list.
+             def similar(str1, str2):
+                 return SequenceMatcher(None, str1, str2).ratio()
+             index_no = 0
+             sim = 0
+             for sub in categories_list:
+                 res = similar(class_name, sub)
+                 if res > sim:
+                     sim = res
+                     index_no = categories_list.index(sub)
+             class_name = categories_list[index_no]
+             return class_name
+
+         if main_class not in main_categories_list:
+             main_class = category_class(main_class, main_categories_list)
+         print("POST PROCESSED MAIN CLASS : ", main_class)
+         tier_1_index_no = main_categories_list.index(main_class) + 1
+
+         with open('categories_json.txt') as f:
+             data = json.load(f)
+         sub_categories_list = data[main_class]
+         print("SUB CATEGORIES LIST", sub_categories_list)
+         with open("sub_categories.txt", "w") as f:
+             no = 1
+             # print(data[main_class])
+             for i in data[main_class]:
+                 f.write(str(no) + ')' + str(i) + '\n')
+                 no = no + 1
+         sub_categories = Path("sub_categories.txt").read_text()
+
+         template2 = '''Given below are the sub-classes of {main_class}.
+         {sub_categories}
+         You are a text classifier that categorises the transcript and captions into the one sub-class whose context they match. Generate only the sub-class name, without explanation.
+         Give more importance to the Transcript while classifying.
+         Transcript: {transcript}
+         Captions: {captions}
+         Return only the Sub-class answer chosen from the list and nothing else.
+         Answer in the format:
+         Main-class => {main_class}
+         Sub-class =>
+         '''
+
+         prompt2 = PromptTemplate(template=template2, input_variables=['sub_categories', 'transcript', 'captions', 'main_class'])
+
+         if self.model_name == 'mistral':
+             try:
+                 chain2 = LLMChain(llm=self.llm, prompt=prompt2)
+                 sub_class = chain2.predict(sub_categories=sub_categories, transcript=transcript, captions=captions, main_class=main_class)
+             except Exception:
+                 prompt2 = template2.format(sub_categories=sub_categories, transcript=transcript, captions=captions, main_class=main_class)
+                 messages = [{"role": "user", "content": prompt2}]
+                 stream = self.client.chat_completion(messages, max_tokens=100)
+                 sub_class = stream.choices[0].message.content.strip()
+
+             print("Preprocess Answer", sub_class)
+
+             try:
+                 pattern = r"Sub-class =>\s*(.+)"
+                 match = re.search(pattern, sub_class)
+                 if match:
+                     sub_class = match.group(1).strip()
+             except Exception:
+                 pass
+         else:
+             prompt_text2 = template2.format(sub_categories=sub_categories, transcript=transcript, captions=captions, main_class=main_class)
+             response = self.genai_model.generate_content(contents=prompt_text2)
+             sub_class = response.text
+             print("Preprocess Answer", sub_class)
+
+         print("SUB CLASS", sub_class)
+         if sub_class not in sub_categories_list:
+             sub_class = category_class(sub_class, sub_categories_list)
+         print("POST PROCESSED SUB CLASS", sub_class)
+         tier_2_index_no = sub_categories_list.index(sub_class) + 1
+         print("ANSWER:", sub_class)
+         final_answer = (f"Tier 1 category : IAB{tier_1_index_no} : {main_class}\nTier 2 category : IAB{tier_1_index_no}-{tier_2_index_no} : {sub_class}")
+
+         first_video = os.path.join(os.path.dirname(__file__), "American_football_heads_to_India_clip.mp4")
+         second_video = os.path.join(os.path.dirname(__file__), "PersonalFinance_clip.mp4")
+         # return final_answer, first_video, second_video
+
+         end_time_generation = time.time()
+         classification_time = end_time_generation - start_time_generation
+         print("MODEL USED :", self.model_name)
+         print("MODEL SETUP TIME :", setup_time)
+         print("TIME TAKEN FOR AUDIO CONVERSION (WHISPER) :", audio_time)
+         print("TIME TAKEN FOR IMAGE CAPTIONING :", caption_time)
+         print("TIME TAKEN FOR CLASS GENERATION :", classification_time)
+         print("TOTAL INFERENCE TIME :", audio_time + caption_time + classification_time)
+         return final_answer
+
+     def save_model_choice(self, model_name):
+         global setup_time
+         start_time_setup = time.time()
+         self.model_name = model_name
+         if self.model_name == 'mistral':
+             print("Setting up Mistral model for Class Selection")
+             self.setup_mistral_space_model()
+         else:
+             print("Setting up Gemini model for Class Selection")
+             self.setup_gemini_model()
+         end_time_setup = time.time()
+         setup_time = end_time_setup - start_time_setup
+         # print("MODEL SETUP TIME", setup_time)
+         return "Model selected: " + model_name
+
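+     # launch_interface wires up two Gradio tabs: "Model Selection" (save_model_choice)
+     # and "Video Classification" (classify_video on an uploaded or example clip).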
+     def launch_interface(self):
+         css_code = """
+         /* Highlight the second tab in the tabbed interface */
+         .gradio-container .tab-labels {
+             background-color: #d9d9d9;
+         }
+
+         /* Highlight the second tab specifically */
+         .gradio-container .tab-label:nth-child(2) {
+             background-color: #ffcc00; /* Yellow color for the second tab */
+         }
+
+         /* Bounding box for inputs and outputs */
+         .gradio-container .gr-box {
+             border: 2px solid #333333; /* Dark border for inputs and outputs */
+             padding: 10px;
+             border-radius: 5px;
+         }
+
+         .gradio-container .input-section {
+             border: 2px solid #009688; /* Teal border for input section */
+             padding: 10px;
+             border-radius: 5px;
+             margin-bottom: 10px;
+         }
+
+         .gradio-container .output-section {
+             border: 2px solid #ff5722; /* Orange border for output section */
+             padding: 10px;
+             border-radius: 5px;
+             margin-bottom: 10px;
+         }
+
+         /* Alignments for clean design */
+         .gradio-container .gr-row {
+             justify-content: center;
+             text-align: center;
+         }
+         """
+         # css_code = """
+         # .gradio-container {background-color: #FFFFFF;color:#000000;background-size: 200px; background-image:url(https://gitlab.ignitarium.in/saran/logo/-/raw/aab7c77b4816b8a4bbdc5588eb57ce8b6c15c72d/ign_logo_white.png);background-repeat:no-repeat; position:relative; top:1px; left:5px; padding: 50px;text-align: right;background-position: right top;}
+         # """
+         # css_code += """
+         # :root {
+         #     --body-background-fill: #FFFFFF; /* New value */
+         # }
+         # """
+         # css_code += """
+         # :root {
+         #     --body-background-fill: #000000; /* New value */
+         # }
+         # """
+
+         interface_1 = gr.Interface(
+             self.save_model_choice,
+             inputs=gr.Dropdown(choices=['gemini', 'mistral'], label="Select Model", info="Default model: Gemini"),
+             # outputs=interface_1_output,
+             outputs="text"
+         )
+
+         video_examples = [
+             [os.path.join(os.path.dirname(__file__), "American_football_heads_to_India_clip.mp4")],
+             [os.path.join(os.path.dirname(__file__), "PersonalFinance_clip.mp4")],
+             [os.path.join(os.path.dirname(__file__), "Motorcycle_clip.mp4")],
+             [os.path.join(os.path.dirname(__file__), "Spirituality_1_clip.mp4")],
+             [os.path.join(os.path.dirname(__file__), "Science_clip.mp4")]
+         ]
+
+         # Define the checkbox for additional feature control
+         checkbox = gr.CheckboxGroup(
+             ["Image Captions and Audio for Classification"],
+             label="Features",
+             info="default : Audio for classification",
+         )
+
+         default_checkbox = []
+
+         demo = gr.Interface(fn=self.classify_video, inputs=["playablevideo", checkbox], allow_flagging='never', examples=video_examples,
+                             cache_examples=False, outputs=["text"],
+                             css=css_code, title="Interactive Advertising Bureau (IAB) compliant Video-Ad classification")
+         # demo.launch(debug=True)
+
+         gr.TabbedInterface([interface_1, demo], ["Model Selection", "Video Classification"]).launch(debug=True)
+
+     def run_inference(self, video_path, model):
+         result = self.classify_video(video_path)
+         print(result)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description='Process some videos.')
+     parser.add_argument("video_path", nargs='?', default=None, help="Path to the video file")
+     parser.add_argument("-n", "--no_of_frames", type=int, default=3, help="Number of frames for image captioning")
+     parser.add_argument("--mode", choices=['interface', 'inference'], default='interface', help="Mode of operation: interface or inference")
+     parser.add_argument("--model", choices=['gemini', 'mistral'], default='gemini', help="Model for inference")
+
+     args = parser.parse_args()
+
+     vc = VideoClassifier(no_of_frames=args.no_of_frames, mode=args.mode, model=args.model)
+
+     if args.mode == 'interface':
+         vc.launch_interface()
+     elif args.mode == 'inference' and args.video_path and args.model:
+         vc.run_inference(args.video_path, args.model)
+     else:
+         print("Error: No video path/model provided for inference mode.")