File size: 3,645 Bytes
5c9bf40
51dfab7
5c9bf40
 
51dfab7
5c9bf40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb1849d
5c9bf40
 
 
 
 
eb1849d
5c9bf40
 
eb1849d
5c9bf40
eb1849d
5c9bf40
 
 
b029c87
5c9bf40
 
b029c87
5c9bf40
b029c87
5c9bf40
b029c87
5c9bf40
 
 
 
b029c87
5c9bf40
 
 
 
 
b029c87
5c9bf40
b029c87
5c9bf40
b029c87
5c9bf40
 
eb1849d
5c9bf40
 
eb1849d
5c9bf40
 
 
 
 
 
 
 
eb1849d
5c9bf40
eb1849d
5c9bf40
eb1849d
5c9bf40
eb1849d
5c9bf40
eb1849d
5c9bf40
 
 
 
eb1849d
5c9bf40
 
 
eb1849d
5c9bf40
 
 
eb1849d
5c9bf40
 
 
 
eb1849d
5c9bf40
 
 
 
eb1849d
5c9bf40
 
 
eb1849d
5c9bf40
 
eb1849d
 
5c9bf40
 
 
eb1849d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# import streamlit as st

# x = st.slider('Select a value')
# st.write(x, 'squared is', x * x)

import streamlit as st
import torch
from PIL import Image
import json
from tqdm import tqdm

from transformers import AutoModelForQuestionAnswering, LayoutLMv2Processor, AutoTokenizer

class Config:
    """Settings bag for the LayoutLMv2 document question-answering demo.

    ``device`` and ``checkpoint`` describe the model to serve.  Many of the
    remaining fields (learning rate, epochs, batching, seed, ...) are not read
    anywhere in this script and look like they were carried over from a
    training configuration -- NOTE(review): confirm before removing them.
    """

    def __init__(self):
        # Data location and model selection.
        self.data_dir: str = "/opt/ml/input/data/"
        self.model: str = "layoutlmv2"
        self.device: str = "cpu"
        self.checkpoint: str = "microsoft/layoutlmv2-base-uncased"
        self.use_ocr_library: bool = False
        self.debug: bool = False
        self.batch_data: int = 1
        self.num_proc: int = 1
        self.shuffle: bool = True

        # Optimisation / sequence-length hyper-parameters.
        self.lr: float = 5e-6
        self.seed: int = 42
        self.batch: int = 1
        self.max_len: int = 512
        self.epochs: int = 1000

        # Miscellaneous switches.
        self.fuzzy: bool = False
        self.model_name: str = ''


config = Config()

def predict_start_first(outputs):
    """Decode one answer span per example, choosing the start token first.

    For each example, the start index is the argmax of ``start_logits``.  The
    end index is then the argmax of ``end_logits`` restricted to positions at
    or after that start, so the returned span is never reversed.

    Args:
        outputs: Model output exposing ``start_logits`` and ``end_logits``
            tensors indexed as ``[example, token]`` (presumably shape
            ``(batch, seq_len)`` -- TODO confirm for other QA heads).

    Returns:
        ``(start_indices, end_indices)``: two lists with one entry per example.
        Start entries are 0-d tensors taken from ``start_logits.argmax(1)``.
        End entries are Python ints.  Ties resolve to the earliest position.
    """
    start_logits = outputs.start_logits
    end_logits = outputs.end_logits

    start_positions = start_logits.argmax(1)

    predicted_start_idx_list = []
    predicted_end_idx_list = []
    for i, start in enumerate(start_positions):
        offset = int(start)
        # A single argmax over the suffix replaces the per-token Python loop.
        # It also guarantees end >= start when every remaining end logit is
        # -inf; the old loop fell back to index 0 there, producing end < start.
        predicted_end_idx_list.append(offset + int(end_logits[i, offset:].argmax()))
        predicted_start_idx_list.append(start)

    return predicted_start_idx_list, predicted_end_idx_list

# Define function to make predictions
def predict(config, model, image, question, processor=None):
    """Answer *question* about the document in *image* with an extractive QA model.

    Args:
        config: Settings object.  ``checkpoint`` selects the processor to load
            when *processor* is not given, ``device`` is where the model inputs
            are placed, and ``max_len`` caps the encoded sequence length.
        model: Question-answering model called with LayoutLMv2-style inputs
            (``input_ids``, ``attention_mask``, ``token_type_ids``, ``bbox``,
            ``image``).  Its output must expose ``start_logits`` and
            ``end_logits``.
        image: Document page image (``main`` passes an RGB PIL image).
        question: Question text.
        processor: Optional pre-loaded ``LayoutLMv2Processor``.  Passing one
            avoids re-loading it from ``config.checkpoint`` on every call.

    Returns:
        The decoded answer string, with special tokens removed.
    """
    if processor is None:
        processor = LayoutLMv2Processor.from_pretrained(config.checkpoint)

    # Called with only an image and a question, the processor OCRs the page
    # and places the words after the question.  Dense pages can exceed the
    # model's sequence limit, so truncate the OCR words (second sequence) and
    # keep the question intact.
    encoding = processor(
        image,
        question,
        truncation="only_second",
        max_length=config.max_len,
        return_tensors="pt",
    )

    # The model lives on config.device (see main); the inputs must match it.
    model_inputs = {
        name: encoding[name].to(config.device)
        for name in ("input_ids", "attention_mask", "token_type_ids", "bbox", "image")
    }
    with torch.no_grad():
        output = model(**model_inputs)

    predicted_start_idx, predicted_end_idx = predict_start_first(output)
    start = int(predicted_start_idx[0])
    end = int(predicted_end_idx[0])

    # Spans touching [CLS]/[SEP] would otherwise leak those markers into the
    # answer shown to the user.
    return processor.tokenizer.decode(
        encoding["input_ids"][0, start:end + 1], skip_special_tokens=True
    )

def _load_model(checkpoint, device):
    """Load the question-answering model from *checkpoint* and move it to *device*."""
    model = AutoModelForQuestionAnswering.from_pretrained(checkpoint).to(device)
    # NOTE(review): the base checkpoint presumably ships without a trained QA
    # head.  The original code had a commented-out
    # `model.load_state_dict(torch.load("model"))` here, apparently meant to
    # load fine-tuned weights -- confirm and restore if needed.
    return model


# Streamlit re-executes this script on every widget interaction.  When the
# installed version provides `st.cache_resource` (Streamlit >= 1.18), wrap the
# loader so the model is loaded once per server process instead of every rerun.
if hasattr(st, "cache_resource"):
    _load_model = st.cache_resource(_load_model)


def main(config):
    """Render the Streamlit UI and answer questions about an uploaded image.

    Args:
        config: Settings object.  ``checkpoint`` and ``device`` choose which
            model is loaded and where it is placed.
    """
    # Load deep learning model
    model = _load_model(config.checkpoint, config.device)

    # Create Streamlit app
    st.title('Deep Learning Pipeline')
    st.write('Upload an image and ask a question to get a prediction')

    # Create file uploader and text input widgets
    uploaded_file = st.file_uploader("Choose an image", type=['jpg', 'jpeg', 'png'])
    question = st.text_input('Ask a question')

    # If file is uploaded, show the image
    image = None
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption='Uploaded Image', use_column_width=True)

    # The button is evaluated first so it is always rendered.  Predict only
    # when an image is present and the question is not blank or whitespace.
    if st.button('Get Prediction') and image is not None and question.strip():
        with st.spinner('Predicting...'):
            output = predict(config, model, image, question)

        # Show the output
        st.write('Output:', output)


if __name__ == "__main__":
    # Build a fresh Config (rebinding the module-level `config`) and start the app.
    config = Config()
    main(config)