Amith Adiraju commited on
Commit
11b899a
·
1 Parent(s): 8dc61da

Feature branch to add feature for asynchronous inference with cv and llm models, for improving latency.

Browse files

1. Added logic to separate main file from inference file, redirection in streamlit added. ( main_page and model_inference )
2. Moved loading of models to model_inference function.
3. Added asynchronous code for most of the inference and pre-processing functions to decrease latency.
4. Added threadpool executor to distribute llm inference, this saves about 25% latency time with llm.
5. Added logic to display inference of certain items that are done quickly than others, improves user experince, they need not wait for all items to be done to see result.
6. Added stateful application with page toggling.

Signed-off-by: Amith Adiraju <amithadiraju@Amiths-Laptop.local>

Documented most of the functions and cleaned up code.

Signed-off-by: Amith Adiraju <amithadiraju@Amiths-Laptop.local>

.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  misc.txt
2
  test_cas.py
3
- test_train_llm.py
 
 
1
  misc.txt
2
  test_cas.py
3
+ test_train_llm.py
4
+ redir_app.py
app.py CHANGED
@@ -1,62 +1,166 @@
1
  import streamlit as st
 
 
 
 
2
 
3
  from inference.translate import (
4
  extract_filter_img,
5
- transcribe_menu_model,
6
- load_models
7
  )
8
 
9
  from inference.config import DEBUG_MODE
10
  from PIL import Image
11
  import time
12
 
13
- # Streamlit app
14
- st.title("Image Upload and Processing")
15
 
 
 
 
16
 
17
- # Using open source text detector, LLM for explaining items
18
- text_extractor, \
19
- item_tokenizer,item_summarizer = load_models(item_summarizer = "google/flan-t5-large")
20
 
21
- # Streamlit function to upload an image from any device
22
- uploaded_file = st.file_uploader("Choose an image...",
23
- type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
25
 
26
- # Submit button
27
- if uploaded_file is not None:
28
- image = Image.open(uploaded_file)
29
 
30
- # Only show if user wants to see
31
- if st.checkbox('Show Uploaded Image'):
32
- st.image(image,
33
- caption='Uploaded Image',
34
- use_column_width=True)
 
 
35
 
36
  # Submit button
37
- if st.button("Submit"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  msg1 = st.empty()
40
  msg1.write("Pre-processing and extracting text out of your image ....")
41
  st_filter = time.perf_counter()
 
42
  # Call the extract_filter_img function
43
- filtered_text = extract_filter_img(image, text_extractor)
44
  en_filter = time.perf_counter()
45
 
46
  num_items_detected = len(filtered_text)
 
47
  if num_items_detected == 0:
48
  st.write("We couldn't detect any menu items ( indian for now ) from your image, please try a different image.")
49
-
50
  elif num_items_detected > 0:
51
- st.write(f"Detected {num_items_detected} menu items ( indian ) from your input image ... ")
52
 
53
  msg2 = st.empty()
54
  msg2.write("All pre-processing done, transcribing your menu items now ....")
55
  st_trans_llm = time.perf_counter()
56
- translated_text_dict = transcribe_menu_model(menu_texts=filtered_text,
57
- text_tokenizer=item_tokenizer,
58
- text_summarizer=item_summarizer
59
- )
60
 
61
  msg3 = st.empty()
62
  msg3.write("Done transcribing ... ")
@@ -74,5 +178,27 @@ if uploaded_file is not None:
74
  st.write("Time took to summarize by LLM {}".format(llm_time_sec))
75
  st.write('Overall time taken in seconds: {}'.format(total_time_sec))
76
 
77
- st.table(translated_text_dict)
78
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from streamlit import session_state as sst
3
+ from typing import List, Optional
4
+ import asyncio
5
+ import pandas as pd
6
 
7
  from inference.translate import (
8
  extract_filter_img,
9
+ transcribe_menu_model
 
10
  )
11
 
12
  from inference.config import DEBUG_MODE
13
  from PIL import Image
14
  import time
15
 
16
+ from concurrent.futures import ThreadPoolExecutor, as_completed
17
+ import os
18
 
19
+ # Setting workers to be 70% of all available virtual cpus in system
20
+ cpu_count = os.cpu_count()
21
+ pool = ThreadPoolExecutor(max_workers=int(cpu_count*0.7) )
22
 
23
+ # Initialize session state variable to start with home page
24
+ if "page" not in sst:
25
+ sst["page"] = "Home"
26
 
27
+ def navigate_to(page: str) -> None:
28
+ """
29
+ Function to set the current page in the state of streamlit. A helper for
30
+ simulating navigation in streamlit.
31
+
32
+ Parameters:
33
+ page: str, required.
34
+
35
+ Returns:
36
+ None
37
+ """
38
+
39
+ sst["page"] = page
40
+
41
+ async def main_page() -> None:
42
+ """
43
+ Function that contains content of main page i.e., image uploader and submit button to navigate to next page.
44
+ Upon submit , control goes to model inference 'page'.
45
 
46
+ Parameters:
47
+ None
48
+
49
+ Returns:
50
+ None
51
+ """
52
 
53
+ # Streamlit app
54
+ first_title = st.empty()
55
+ first_title.title("App that explains your menu items ")
56
 
57
+
58
+ # Streamlit function to upload an image from any device
59
+ uploaded_file = st.file_uploader("Choose an image...",
60
+ type=["jpg", "jpeg", "png"])
61
+
62
+ # Remove preivous states' value of input image if it exists
63
+ sst.pop('input_image', None)
64
 
65
  # Submit button
66
+ if uploaded_file is not None:
67
+ image = Image.open(uploaded_file)
68
+
69
+ # Only show if user wants to see
70
+ if st.checkbox('Show Uploaded Image'):
71
+ st.image(image,
72
+ caption='Uploaded Image',
73
+ use_column_width=True)
74
+
75
+ sst["input_image"] = image
76
+
77
+ # Submit button
78
+ st.button("Submit",
79
+ on_click = navigate_to,
80
+ args = ("Inference",))
81
+
82
+
83
+ st.info("""This application is for education purposes only. It uses AI, hence it's dietary
84
+ recommendations are not to be taken as medical advice, author doesn't bear responsibility
85
+ for incorrect dietary recommendations. Please proceed with caution.
86
+ """)
87
+
88
+
89
+ async def dist_llm_inference(inp_texts: List[str]) -> None:
90
+
91
+ """
92
+ Function that performs concurrent LLM inference using threadpool. It displays
93
+ results of those threads that are done with execution, as a dynamic row to streamlit table, rather than
94
+ waiting for all threads to be done.
95
+
96
+ Parameters:
97
+ inp_texts: List[str], required -> List of strings, containing item names of a menu in english.
98
+
99
+ Returns:
100
+ None
101
+ """
102
+
103
+ df = pd.DataFrame([('ITEM NAME', 'EXPLANATION')]
104
+ )
105
+
106
+ sl_table = st.table(df)
107
+ tp_futures = { pool.submit(transcribe_menu_model, mi): mi for mi in inp_texts }
108
+
109
+ for tpftr in as_completed(tp_futures):
110
+
111
+ item = tp_futures[tpftr]
112
+
113
+ try:
114
+ exp = tpftr.result()
115
+ sl_table.add_rows([(item,exp)] )
116
+
117
+ except Exception as e:
118
+ print("Could not add a new row dynamically, because of this error:", e)
119
+
120
+ return
121
+
122
+
123
+ async def model_inference():
124
+
125
+ """
126
+ Function that pre-processes input text from state variables, does concurrent inference
127
+ and toggles state between pages if needed.
128
+
129
+ Parameters:
130
+ None
131
+ Returns:
132
+ None
133
+
134
+ """
135
+
136
+ second_title = st.empty()
137
+ second_title.title(" Using ML to explain your menu items ... ")
138
+
139
+ if "input_image" in sst:
140
+
141
+ image = sst["input_image"]
142
 
143
  msg1 = st.empty()
144
  msg1.write("Pre-processing and extracting text out of your image ....")
145
  st_filter = time.perf_counter()
146
+
147
  # Call the extract_filter_img function
148
+ filtered_text = await extract_filter_img(image)
149
  en_filter = time.perf_counter()
150
 
151
  num_items_detected = len(filtered_text)
152
+
153
  if num_items_detected == 0:
154
  st.write("We couldn't detect any menu items ( indian for now ) from your image, please try a different image.")
155
+
156
  elif num_items_detected > 0:
157
+ st.write(f"Detected {num_items_detected} menu items from your input image ... ")
158
 
159
  msg2 = st.empty()
160
  msg2.write("All pre-processing done, transcribing your menu items now ....")
161
  st_trans_llm = time.perf_counter()
162
+
163
+ await dist_llm_inference(filtered_text)
 
 
164
 
165
  msg3 = st.empty()
166
  msg3.write("Done transcribing ... ")
 
178
  st.write("Time took to summarize by LLM {}".format(llm_time_sec))
179
  st.write('Overall time taken in seconds: {}'.format(total_time_sec))
180
 
181
+
182
+ st.button("translate another",
183
+ on_click=navigate_to,
184
+ args=("Home",))
185
+
186
+ else:
187
+ st.write("Looks like image upload failed, please try uploading it again ... ")
188
+
189
+
190
+ async def main():
191
+ """
192
+ Function that toggles between pages based on state variables.
193
+
194
+ Parameters:
195
+ None
196
+ Returns:
197
+ None
198
+ """
199
+ if sst["page"] == "Home":
200
+ await main_page()
201
+ elif sst["page"] == "Inference":
202
+ await model_inference()
203
+
204
+ asyncio.run(main())
inference/config.py CHANGED
@@ -29,6 +29,5 @@ Based on Item and explanation pairs provided above, provide similar explanation
29
  Item ->
30
  """
31
 
32
- DEBUG_MODE = True
33
-
34
  DEVICE = 'cpu'
 
29
  Item ->
30
  """
31
 
32
+ DEBUG_MODE = False
 
33
  DEVICE = 'cpu'
inference/preprocess_image.py CHANGED
@@ -11,6 +11,18 @@ import re
11
 
12
 
13
  def preprocess_text(sentence: AnyStr) -> AnyStr:
 
 
 
 
 
 
 
 
 
 
 
 
14
  sentence=sentence.lower().replace('{html}',"")
15
  cleanr = re.compile('<.*?>')
16
  cleantext = re.sub(cleanr, '', sentence)
@@ -27,15 +39,25 @@ def preprocess_text(sentence: AnyStr) -> AnyStr:
27
  return return_txt
28
 
29
  def image_to_np_arr(image) -> np.array:
 
 
 
 
 
 
 
 
 
 
 
30
  return np.array(image)
31
 
32
- def process_extracted_text(raw_extrc_text: List[Tuple]) -> List[AnyStr]:
33
 
34
  output_texts = []
35
  for _, extr_text, _ in raw_extrc_text:
36
  # remove all numbers, special characters from a string
37
  prcsd_txt = preprocess_text(extr_text)
38
-
39
- if len(prcsd_txt.split(" ") ) > 2: output_texts.append(prcsd_txt)
40
 
41
  return output_texts
 
11
 
12
 
13
  def preprocess_text(sentence: AnyStr) -> AnyStr:
14
+
15
+ """
16
+ Function that pre-processes input text by removing special characters, hyper links,
17
+ numbers and by removing stop words
18
+
19
+ Parameters:
20
+ sentence: str, required -> A raw string which may have stop words, special chars etc.
21
+
22
+ Returns:
23
+ return_txt: str -> A clean string with all aforementioned, removed.
24
+ """
25
+
26
  sentence=sentence.lower().replace('{html}',"")
27
  cleanr = re.compile('<.*?>')
28
  cleantext = re.sub(cleanr, '', sentence)
 
39
  return return_txt
40
 
41
  def image_to_np_arr(image) -> np.array:
42
+
43
+ """
44
+ Function that converts a byte array image into a floating pointer numpy array.
45
+
46
+ Parameters:
47
+ inp_texts: List[str], required -> List of strings, containing item names of a menu in english.
48
+
49
+ Returns:
50
+ np.ndarray
51
+ """
52
+
53
  return np.array(image)
54
 
55
+ async def process_extracted_text(raw_extrc_text: List[Tuple]) -> List[AnyStr]:
56
 
57
  output_texts = []
58
  for _, extr_text, _ in raw_extrc_text:
59
  # remove all numbers, special characters from a string
60
  prcsd_txt = preprocess_text(extr_text)
61
+ if len(prcsd_txt.split(" ") ) >= 2: output_texts.append(prcsd_txt)
 
62
 
63
  return output_texts
inference/translate.py CHANGED
@@ -14,9 +14,21 @@ import time
14
  use_gpu = True
15
  if DEVICE == 'cpu': use_gpu = False
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Define your extract_filter_img function
19
- def extract_filter_img(image, text_extractor) -> Dict:
20
 
21
  """
22
  1. Convert Image to numpy array
@@ -48,7 +60,8 @@ def extract_filter_img(image, text_extractor) -> Dict:
48
  if i in ind_add_delays:
49
  time.sleep(0.5)
50
 
51
- result = func(result)
 
52
 
53
  status_message.write(end_message)
54
 
@@ -63,42 +76,22 @@ def extract_filter_img(image, text_extractor) -> Dict:
63
  return result
64
 
65
 
66
- def transcribe_menu_model(menu_texts: List[AnyStr],
67
- text_summarizer = None,
68
- text_tokenizer = None) -> Dict:
69
-
70
- summarized_menu_items = {}
71
 
72
- for mi in menu_texts:
73
- if not text_summarizer:
74
- raise NotImplementedError(""" """)
75
-
76
- else:
77
- prompt_item = INSTRUCTION_PROMPT + " " + mi + """
78
 
79
 
80
  """
81
- input_ids = text_tokenizer(prompt_item, return_tensors="pt").input_ids
82
-
83
- outputs = text_summarizer.generate(input_ids,
84
- max_new_tokens = 512
85
- )
86
-
87
- summarized_menu_items[mi] = text_tokenizer.decode(
88
- outputs[0],
89
- skip_special_tokens = True
90
- )
91
 
92
- return summarized_menu_items
93
-
94
- def load_models(item_summarizer: AnyStr) -> Tuple:
95
- text_extractor = easyocr.Reader(['en'],
96
- gpu = use_gpu
97
- )
98
- tokenizer = T5Tokenizer.from_pretrained(item_summarizer)
99
- model = T5ForConditionalGeneration.from_pretrained(item_summarizer)
100
-
101
- return (text_extractor, tokenizer, model)
102
 
103
  def classify_menu_text(extrc_str: List[AnyStr]) -> List[AnyStr]:
104
  return extrc_str
 
14
  use_gpu = True
15
  if DEVICE == 'cpu': use_gpu = False
16
 
17
+ @st.cache_resource
18
+ def load_models(item_summarizer: AnyStr) -> Tuple:
19
+ text_extractor = easyocr.Reader(['en'],
20
+ gpu = use_gpu
21
+ )
22
+ tokenizer = T5Tokenizer.from_pretrained(item_summarizer)
23
+ model = T5ForConditionalGeneration.from_pretrained(item_summarizer)
24
+
25
+ return (text_extractor, tokenizer, model)
26
+
27
+ text_extractor,item_tokenizer,item_summarizer = load_models(item_summarizer = "google/flan-t5-large")
28
+
29
 
30
  # Define your extract_filter_img function
31
+ async def extract_filter_img(image) -> Dict:
32
 
33
  """
34
  1. Convert Image to numpy array
 
60
  if i in ind_add_delays:
61
  time.sleep(0.5)
62
 
63
+ if i == 2: result = await func(result)
64
+ else: result = func(result)
65
 
66
  status_message.write(end_message)
67
 
 
76
  return result
77
 
78
 
79
+ def transcribe_menu_model(menu_text: List[AnyStr]) -> Dict:
 
 
 
 
80
 
81
+ prompt_item = INSTRUCTION_PROMPT + " " + menu_text + """
 
 
 
 
 
82
 
83
 
84
  """
85
+ input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids
 
 
 
 
 
 
 
 
 
86
 
87
+ outputs = item_summarizer.generate(input_ids,
88
+ max_new_tokens = 512
89
+ )
90
+
91
+ return item_tokenizer.decode(
92
+ outputs[0],
93
+ skip_special_tokens = True
94
+ )
 
 
95
 
96
  def classify_menu_text(extrc_str: List[AnyStr]) -> List[AnyStr]:
97
  return extrc_str