João Pedro committed on
Commit
11deda5
·
1 Parent(s): f0dc05e

add wandb logging for user feedback

Browse files
Files changed (1) hide show
  1. app.py +29 -24
app.py CHANGED
@@ -1,33 +1,24 @@
 
 
1
  import streamlit as st
 
2
  from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
3
  from pdf2image import convert_from_bytes
4
  from PIL import Image
5
 
6
- labels = [
7
- 'budget',
8
- 'email',
9
- 'form',
10
- 'handwritten',
11
- 'invoice',
12
- 'language',
13
- 'letter',
14
- 'memo',
15
- 'news article',
16
- 'questionnaire',
17
- 'resume',
18
- 'scientific publication',
19
- 'specification',
20
- ]
21
- id2label = {i: label for i, label in enumerate(labels)}
22
- label2id = {v: k for k, v in id2label.items()}
23
 
24
- processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
25
- model = LayoutLMv3ForSequenceClassification.from_pretrained(
26
- "microsoft/layoutlmv3-base",
27
- num_labels=len(labels),
28
- id2label=id2label,
29
- label2id=label2id,
30
- )
31
 
32
  st.title("Document Classification with LayoutLMv3")
33
 
@@ -36,6 +27,8 @@ uploaded_file = st.file_uploader(
36
  )
37
 
38
  if uploaded_file:
 
 
39
  if uploaded_file.type == "application/pdf":
40
  images = convert_from_bytes(uploaded_file.getvalue())
41
  else:
@@ -61,9 +54,21 @@ if uploaded_file:
61
  "Is the classification correct?", ("Yes", "No"),
62
  key=f'prediction-{i}'
63
  )
 
64
  if feedback == "No":
65
  correct_label = st.selectbox(
66
  "Please select the correct label:", labels,
67
  key=f'selectbox-{i}'
68
  )
69
  print(f'Correct label for image {i}: {correct_label}')
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import wandb
3
  import streamlit as st
4
+ from constants import PROJECT_NAME
5
  from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
6
  from pdf2image import convert_from_bytes
7
  from PIL import Image
8
 
9
+ wandb_api_key = os.getenv("WANDB_API_KEY")
10
+ if not wandb_api_key:
11
+ st.error(
12
+ "Couldn't find WanDB API key. Please set it up as an environment variable",
13
+ icon="🚨",
14
+ )
15
+ else:
16
+ wandb.login(key=wandb_api_key)
 
 
 
 
 
 
 
 
 
17
 
18
+ processor = LayoutLMv3Processor.from_pretrained("model/layoutlmv3/")
19
+ model = LayoutLMv3ForSequenceClassification.from_pretrained("model/layoutlmv3/")
20
+ id2label = model.config.id2label
21
+ label2id = model.config.label2id
 
 
 
22
 
23
  st.title("Document Classification with LayoutLMv3")
24
 
 
27
  )
28
 
29
  if uploaded_file:
30
+ run = wandb.init(project=PROJECT_NAME, name='feedback-loop')
31
+
32
  if uploaded_file.type == "application/pdf":
33
  images = convert_from_bytes(uploaded_file.getvalue())
34
  else:
 
54
  "Is the classification correct?", ("Yes", "No"),
55
  key=f'prediction-{i}'
56
  )
57
+
58
  if feedback == "No":
59
  correct_label = st.selectbox(
60
  "Please select the correct label:", labels,
61
  key=f'selectbox-{i}'
62
  )
63
  print(f'Correct label for image {i}: {correct_label}')
64
+
65
+ run.log({
66
+ 'filepath': uploaded_file,
67
+ 'filetype': uploaded_file.type,
68
+ 'predicted_label': id2label[prediction],
69
+ 'predicted_label_id': prediction,
70
+ 'correct_label': correct_label,
71
+ 'correct_label_id': label2id[correct_label]
72
+ })
73
+
74
+ run.finish()