Rick7799 committed on
Commit
3dad239
1 Parent(s): 8c018f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -11
app.py CHANGED
@@ -1,18 +1,35 @@
1
  import gradio as gr
2
- from transformers import RAGMultiModalModel # Importing the ColPali model
 
 
3
 
4
- # Initialize the ColPali model
5
- model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
 
 
6
 
7
  def extract_and_search(image, keyword):
8
- # Use the model to extract text from the image
9
- inputs = {"images": [image]}
10
- extracted_text = model.generate(**inputs) # Replace with actual prediction method
11
-
12
- # Perform keyword search
13
- matching_lines = [line for line in extracted_text.splitlines() if keyword.lower() in line.lower()]
14
-
15
- return extracted_text, matching_lines
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Create Gradio interface
18
  interface = gr.Interface(
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+ from PIL import Image
5
 
6
# Load the ColPali checkpoint and its tokenizer from the Hugging Face hub.
# This runs at import time and downloads weights on first use.
# NOTE(review): ColPali is a vision retrieval model; AutoTokenizer /
# AutoModelForSeq2SeqLM may not be the intended loader classes for this
# checkpoint — confirm against the model card.
model_name = "vidore/colpali-v1.2" # Use the correct model identifier
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
 
11
def extract_and_search(image, keyword):
    """Extract text from *image* with the loaded model and filter it by *keyword*.

    Parameters:
        image: PIL image to read text from; converted to RGB when needed.
        keyword: case-insensitive substring used to select matching lines.

    Returns:
        A ``(extracted_text, matching_lines)`` tuple on success, or an
        ``(error message, [])`` tuple when any step fails.
    """
    try:
        # The model expects RGB input; normalise other PIL modes first.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # NOTE(review): passing images= to an AutoTokenizer is unusual; an
        # AutoProcessor is the typical API for vision inputs — confirm.
        encoded = tokenizer(images=image, return_tensors="pt")

        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            generated = model.generate(**encoded)

        text = tokenizer.decode(generated[0], skip_special_tokens=True)

        # Case-insensitive substring match, line by line.
        hits = list(
            filter(lambda line: keyword.lower() in line.lower(), text.splitlines())
        )
        return text, hits
    except Exception as e:
        # Report the failure to the caller (Gradio UI) instead of raising.
        return f"Error during extraction: {str(e)}", []
33
 
34
  # Create Gradio interface
35
  interface = gr.Interface(