João Pedro commited on
Commit
61dba08
·
1 Parent(s): 6849e5c

use id2label for human-readable label

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -6,6 +6,8 @@ from PIL import Image
6
  # Load model and processor
7
  processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
8
  model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
 
 
9
 
10
  st.title("Document Classification with LayoutLMv3")
11
 
@@ -33,14 +35,15 @@ if uploaded_file:
33
  )
34
  st.text(f'encoding shape: {encoding}')
35
  outputs = model(**encoding)
36
- predictions = outputs.logits.argmax(-1)
37
 
38
  # Display predictions (you may want to map indices to labels)
39
- st.write(f"Predictions: {predictions}")
40
 
41
  # User feedback section
42
  feedback = st.radio(
43
- "Is the classification correct?", ("Yes", "No")
 
44
  )
45
  if feedback == "No":
46
  correct_label = st.text_input(
 
6
  # Load model and processor
7
  processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
8
  model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
9
+ id2label = model.config.id2label
10
+ print(id2label)
11
 
12
  st.title("Document Classification with LayoutLMv3")
13
 
 
35
  )
36
  st.text(f'encoding shape: {encoding}')
37
  outputs = model(**encoding)
38
+ prediction = outputs.logits.argmax(-1)[0]
39
 
40
  # Display predictions (you may want to map indices to labels)
41
+ st.write(f"Prediction: {id2label[prediction]}")
42
 
43
  # User feedback section
44
  feedback = st.radio(
45
+ "Is the classification correct?", ("Yes", "No"),
46
+ key=f'prediction-{i}'
47
  )
48
  if feedback == "No":
49
  correct_label = st.text_input(