tc-ha committed on
Commit
eb1849d
1 Parent(s): bc94088

Add application file

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ import json
5
+ from tqdm import tqdm
6
+
7
+ import hydra
8
+ from transformers import AutoModelForQuestionAnswering, LayoutLMv2Processor, AutoTokenizer
9
+ from data_loader.data_loaders import DataLoader
10
+ from utils.util import predict_start_first
11
+
12
class Config:
    """Plain settings container for the demo application.

    Every attribute is assigned a fixed default in ``__init__``; the
    constructor takes no arguments. Most of these fields are not read in
    this file -- presumably they mirror a configuration used elsewhere in
    the project (TODO confirm against the project modules).
    """

    def __init__(self):
        # Data / model selection.
        self.data_dir = "/opt/ml/input/data/"
        self.model = "layoutlmv2"
        self.device = "cpu"
        self.checkpoint = "microsoft/layoutlmv2-base-uncased"
        self.use_ocr_library = False
        self.debug = False
        self.batch_data = 1
        self.num_proc = 1
        self.shuffle = True

        # Numeric hyper-parameters.
        self.lr = 5e-6
        self.seed = 42
        self.batch = 1
        self.max_len = 512
        self.epochs = 1000

        # Miscellaneous switches.
        self.fuzzy = False
        self.model_name = ""


# Module-level default configuration instance.
config = Config()
34
+
35
+ # Define function to make predictions
36
def predict(config, model, image, question):
    """Answer ``question`` about ``image`` and return the answer text.

    Args:
        config: Settings object. ``config.checkpoint`` selects the
            LayoutLMv2 processor to load and ``config.device`` is the device
            the encoded inputs are moved to.
        model: Question-answering model; it is called with the encoded
            ``input_ids``, ``attention_mask``, ``token_type_ids``, ``bbox``
            and ``image`` tensors. It is expected to already be on
            ``config.device``.
        image: Document image handed to the processor (the caller in this
            file passes an RGB ``PIL.Image``).
        question: The question text.

    Returns:
        The decoded text of the predicted answer span.
    """
    # NOTE(review): the processor is re-created on every call; consider
    # caching it if prediction latency matters.
    processor = LayoutLMv2Processor.from_pretrained(config.checkpoint)

    # Keep the inputs on the same device as the model; without this the
    # forward pass fails as soon as config.device is not "cpu".
    encoding = processor(image, question, return_tensors="pt").to(config.device)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        output = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            token_type_ids=encoding["token_type_ids"],
            bbox=encoding["bbox"],
            image=encoding["image"],
        )

    predicted_start_idx, predicted_end_idx = predict_start_first(output)

    # A single (image, question) pair is encoded, so only entry 0 is used.
    # The end index is treated as inclusive, hence the +1 in the slice.
    start = predicted_start_idx[0]
    end = predicted_end_idx[0]
    answer = processor.tokenizer.decode(encoding["input_ids"][0, start:end + 1])

    return answer
75
+
76
def main(config):
    """Build the Streamlit page and run a prediction when requested.

    Args:
        config: Settings object. ``config.checkpoint`` selects the model to
            load and ``config.device`` is the device it is moved to; the
            object is also forwarded to ``predict``.
    """
    # Reset Hydra's global state before doing anything else.
    hydra.core.global_hydra.GlobalHydra.instance().clear()

    # Load the question-answering model named by the configuration.
    # NOTE(review): no fine-tuned weights are loaded here (an earlier
    # load_state_dict call was commented out) -- confirm that serving the
    # plain pretrained checkpoint is intended.
    model = AutoModelForQuestionAnswering.from_pretrained(config.checkpoint).to(config.device)

    # Page header.
    st.title('Deep Learning Pipeline')
    st.write('Upload an image and ask a question to get a prediction')

    # Input widgets.
    uploaded_file = st.file_uploader("Choose an image", type=['jpg', 'jpeg', 'png'])
    question = st.text_input('Ask a question')

    # Show the uploaded image, if any.
    image = None
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Get Prediction'):
        if image is None or not question.strip():
            # Tell the user why nothing happened instead of silently
            # ignoring the click.
            st.warning('Please upload an image and enter a question first.')
        else:
            with st.spinner('Predicting...'):
                output = predict(config, model, image, question)

            st.write('Output:', output)
108
+
109
+
110
# Script entry point: build a fresh configuration and launch the app.
if __name__ == "__main__":
    config = Config()
    main(config)
113
+
114
+
115
+
116
+
117
+