VinitT committed
Commit 36d8cb0 · verified · 1 Parent(s): 2f43b7c

Update app.py

Files changed (1)
  1. app.py +84 -45
app.py CHANGED
@@ -2,57 +2,96 @@ import streamlit as st
 from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
 from PIL import Image
 import torch
+import cv2
+import tempfile
 
 # Load the processor and model directly
 processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 
+# Check if CUDA is available and set the device accordingly
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
 # Streamlit app
-st.title("Image Description Generator")
+st.title("Media Description Generator")
 
-uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+uploaded_file = st.file_uploader("Choose an image or video...", type=["jpg", "jpeg", "png", "mp4", "avi", "mov"])
 
 if uploaded_file is not None:
-    # Open the image
-    image = Image.open(uploaded_file)
-    st.image(image, caption='Uploaded Image.', use_column_width=True)
-    st.write("Generating description...")
-
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "image",
-                    "image": image,
-                },
-                {"type": "text", "text": "Describe this image."},
-            ],
-        }
-    ]
-
-    # Preparation for inference
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-
-    # Pass the image to the processor
-    inputs = processor(
-        text=[text],
-        images=[image],
-        padding=True,
-        return_tensors="pt",
-    )
-    inputs = inputs.to("cpu")
-
-    # Inference: Generation of the output
-    generated_ids = model.generate(**inputs, max_new_tokens=128)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
-
-    st.write("Description:")
-    st.write(output_text)
+    file_type = uploaded_file.type.split('/')[0]
+
+    if file_type == 'image':
+        # Open the image
+        image = Image.open(uploaded_file)
+        st.image(image, caption='Uploaded Image.', use_column_width=True)
+        st.write("Generating description...")
+
+    elif file_type == 'video':
+        # Save the uploaded video to a temporary file
+        tfile = tempfile.NamedTemporaryFile(delete=False)
+        tfile.write(uploaded_file.read())
+
+        # Open the video file
+        cap = cv2.VideoCapture(tfile.name)
+
+        # Extract the first frame
+        ret, frame = cap.read()
+        if not ret:
+            st.error("Failed to read the video file.")
+            st.stop()
+        else:
+            # Convert the frame to an image
+            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
+            st.write("Generating description...")
+
+        # Release the video capture object
+        cap.release()
+
+    else:
+        st.error("Unsupported file type.")
+        st.stop()
+
+    # Add a text input for the user to ask a question
+    user_question = st.text_input("Ask a question about the image or video:")
+
+    if user_question:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": image,
+                    },
+                    {"type": "text", "text": user_question},
+                ],
+            }
+        ]
+
+        # Preparation for inference
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+        # Pass the image to the processor
+        inputs = processor(
+            text=[text],
+            images=[image],
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to(device)  # Ensure inputs are on the same device as the model
+
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=128)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+        st.write("Description:")
+        st.write(output_text[0])
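
For reference, the updated inference path can be exercised outside Streamlit. The snippet below is a minimal sketch, not part of the commit: it reuses only calls that already appear in app.py (AutoProcessor, Qwen2VLForConditionalGeneration, apply_chat_template, generate, batch_decode) and mirrors the device handling introduced here; the input path sample.jpg is a hypothetical placeholder.

# Minimal sanity-check sketch (not part of the commit): one image, one question,
# same Qwen2-VL pipeline and device handling as app.py.
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

image = Image.open("sample.jpg")  # hypothetical local test image
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Build the chat prompt, preprocess text + image together, and generate.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(device)
generated_ids = model.generate(**inputs, max_new_tokens=128)

# Drop the prompt tokens so only the newly generated answer is decoded.
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])

The Streamlit app itself is launched in the usual way with streamlit run app.py.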