lukiod committed on
Commit fff6204
1 Parent(s): e0e0a96

Update app.py

Files changed (1)
  1. app.py +41 -155
app.py CHANGED
@@ -1,176 +1,62 @@
 import streamlit as st
-import torch
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
 from qwen_vl_utils import process_vision_info
-from byaldi import RAGMultiModalModel
 from PIL import Image
-import io
-import time
-import nltk
-from nltk.translate.bleu_score import sentence_bleu
-
-# Download NLTK data for BLEU score calculation
-nltk.download('punkt', quiet=True)
+import torch

-# Load models and processors
+# Load the model and processor
 @st.cache_resource
-def load_models():
-    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
-
-    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
-        "Qwen/Qwen2-VL-7B-Instruct",
-        torch_dtype=torch.bfloat16,
-        attn_implementation="flash_attention_2",
-        device_map="auto",
-        trust_remote_code=True
-    ).cuda().eval()
-
-    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
-
-    return RAG, qwen_model, qwen_processor
+def load_model():
+    # Load Qwen2-VL-7B on CPU
+    model = Qwen2VLForConditionalGeneration.from_pretrained(
+        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.float32, device_map="cpu"
+    )
+    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+    return model, processor

-RAG, qwen_model, qwen_processor = load_models()
+model, processor = load_model()

-# Function to get current CUDA memory usage
-def get_cuda_memory_usage():
-    return torch.cuda.memory_allocated() / 1024**2  # Convert to MB
+# Streamlit Interface
+st.title("Qwen2-VL-7B Multimodal Demo")
+st.write("Upload an image and provide a text prompt to see the model's response.")

-# Define processing functions
-def extract_text_with_colpali(image):
-    start_time = time.time()
-    start_memory = get_cuda_memory_usage()
-
-    extracted_text = RAG.extract_text(image)
-
-    end_time = time.time()
-    end_memory = get_cuda_memory_usage()
-
-    return extracted_text, {
-        'time': end_time - start_time,
-        'memory': end_memory - start_memory
-    }
+# Image uploader
+image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

-def process_with_qwen(query, extracted_text, image, extract_mode=False):
-    start_time = time.time()
-    start_memory = get_cuda_memory_usage()
-
-    if extract_mode:
-        instruction = "Extract and list all text visible in this image, including both printed and handwritten text."
-    else:
-        instruction = f"Context: {extracted_text}\n\nQuery: {query}"
-
+# Text input field
+text = st.text_input("Enter a text description or query")
+
+# If both image and text are provided
+if image and text:
+    # Load image with PIL
+    img = Image.open(image)
+    st.image(img, caption="Uploaded Image", use_column_width=True)
+
+    # Prepare inputs for Qwen2-VL
     messages = [
         {
             "role": "user",
             "content": [
-                {
-                    "type": "text",
-                    "text": instruction
-                },
-                {
-                    "type": "image",
-                    "image": image,
-                },
+                {"type": "image", "image": img},
+                {"type": "text", "text": text},
             ],
         }
     ]
-    text = qwen_processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    image_inputs, video_inputs = process_vision_info(messages)
-    inputs = qwen_processor(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
-        return_tensors="pt",
-    )
-    inputs = inputs.to("cuda")
-    generated_ids = qwen_model.generate(**inputs, max_new_tokens=200)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = qwen_processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
-
-    end_time = time.time()
-    end_memory = get_cuda_memory_usage()
-
-    return output_text[0], {
-        'time': end_time - start_time,
-        'memory': end_memory - start_memory
-    }
-
-# Function to calculate BLEU score
-def calculate_bleu(reference, hypothesis):
-    reference_tokens = nltk.word_tokenize(reference.lower())
-    hypothesis_tokens = nltk.word_tokenize(hypothesis.lower())
-    return sentence_bleu([reference_tokens], hypothesis_tokens)
-
-# Streamlit UI
-st.title("Document Processing with ColPali and Qwen")
-
-uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
-query = st.text_input("Enter your query:")
-
-if uploaded_file is not None and query:
-    image = Image.open(uploaded_file)
-    st.image(image, caption="Uploaded Image", use_column_width=True)
-
-    if st.button("Process"):
-        with st.spinner("Processing..."):
-            # Extract text using ColPali
-            colpali_extracted_text, colpali_metrics = extract_text_with_colpali(image)
-
-            # Extract text using Qwen
-            qwen_extracted_text, qwen_extract_metrics = process_with_qwen("", "", image, extract_mode=True)
-
-            # Process the query with Qwen2, using both extracted text and image
-            qwen_response, qwen_response_metrics = process_with_qwen(query, colpali_extracted_text, image)
-
-            # Calculate BLEU score between ColPali and Qwen extractions
-            bleu_score = calculate_bleu(colpali_extracted_text, qwen_extracted_text)

-            # Display results
-            st.subheader("Results")
-            st.write("ColPali Extracted Text:")
-            st.write(colpali_extracted_text)
-
-            st.write("Qwen Extracted Text:")
-            st.write(qwen_extracted_text)
-
-            st.write("Qwen Response:")
-            st.write(qwen_response)
+    # Prepare for inference
+    text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    image_inputs, _ = process_vision_info(messages)
+    inputs = processor(text=[text_input], images=image_inputs, padding=True, return_tensors="pt")

-            # Display metrics
-            st.subheader("Metrics")
-
-            st.write("ColPali Extraction:")
-            st.write(f"Time: {colpali_metrics['time']:.2f} seconds")
-            st.write(f"Memory: {colpali_metrics['memory']:.2f} MB")
-
-            st.write("Qwen Extraction:")
-            st.write(f"Time: {qwen_extract_metrics['time']:.2f} seconds")
-            st.write(f"Memory: {qwen_extract_metrics['memory']:.2f} MB")
-
-            st.write("Qwen Response:")
-            st.write(f"Time: {qwen_response_metrics['time']:.2f} seconds")
-            st.write(f"Memory: {qwen_response_metrics['memory']:.2f} MB")
-
-            st.write(f"BLEU Score: {bleu_score:.4f}")
+    # Move tensors to CPU
+    inputs = inputs.to("cpu")

-st.markdown("""
-## How to Use
+    # Run the model and generate output
+    with torch.no_grad():
+        generated_ids = model.generate(**inputs, max_new_tokens=128)

-1. Upload an image containing text or a document.
-2. Enter your query about the document.
-3. Click 'Process' to see the results.
+    # Decode the output text
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

-The app will display:
-- Text extracted by ColPali
-- Text extracted by Qwen
-- Qwen's response to your query
-- Performance metrics for each step
-- BLEU score comparing ColPali and Qwen extractions
-""")
+    # Display the response
+    st.write("Model's response:", generated_text)