Spaces:
Sleeping
Sleeping
File size: 3,353 Bytes
11deda5 056ccc3 5639711 c1d4001 056ccc3 2bc05a3 11deda5 d73399c 9652b01 6c6f2d5 056ccc3 fd98f6f 056ccc3 ff1aca1 90ca81e 6c6f2d5 74d6a50 6c6f2d5 11deda5 6c6f2d5 fd98f6f 548ee28 fd98f6f 056ccc3 fd98f6f 180d132 63eb0c6 6c6f2d5 056ccc3 61dba08 056ccc3 fd98f6f 61dba08 fd98f6f 11deda5 fd98f6f 392dd2d d280e22 056ccc3 8f5a987 11deda5 ff1aca1 c939453 b90d0b6 11deda5 c9ae4da c939453 c9ae4da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os
import wandb
import streamlit as st
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from pdf2image import convert_from_bytes
from PIL import Image
# W&B credentials are read from the environment so the key is never hard-coded.
wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    # Surface the missing key in the UI. Execution continues, so the wandb.init()
    # call further down must find credentials some other way.
    # NOTE(review): e.g. a cached ~/.netrc login; confirm this fall-through is intended.
    st.error(
        "Couldn't find WanDB API key. Please set it up as an environment variable",
        icon="🚨",
    )
else:
    wandb.login(key=wandb_api_key)
# Document classes; a label's position in this list is its class id.
# NOTE(review): presumably this matches the output order of the fine-tuned model — confirm against its config.
labels = [
    'budget',
    'email',
    'form',
    'handwritten',
    'invoice',
    'language',
    'letter',
    'memo',
    'news article',
    'questionnaire',
    'resume',
    'scientific publication',
    'specification',
]

# Two-way lookups between integer class ids and label names.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in enumerate(labels)}
@st.cache_resource
def _load_layoutlmv3(model_dir="model/layoutlmv3/"):
    """Load the fine-tuned LayoutLMv3 classifier and its processor from ``model_dir``.

    ``st.cache_resource`` runs this once per Streamlit server process and shares
    the returned objects across all sessions. Storing them only in
    ``st.session_state`` would load a fresh copy of the weights for every
    browser session.

    Returns:
        A ``(model, processor)`` tuple.
    """
    return (
        LayoutLMv3ForSequenceClassification.from_pretrained(model_dir),
        LayoutLMv3Processor.from_pretrained(model_dir),
    )


model, processor = _load_layoutlmv3()
# Mirror the shared objects into session_state so any code reading these keys keeps working.
st.session_state.model = model
st.session_state.processor = processor
st.title("Document Classification with LayoutLMv3")
uploaded_file = st.file_uploader(
    "Upload Document", type=["pdf", "jpg", "png"], accept_multiple_files=False
)

# Streamlit re-executes this whole script on every widget interaction. A table
# built at module level would therefore be empty again on the rerun triggered by
# "Submit all feedback", discarding every row added on earlier reruns. Keeping it
# in session_state lets the rows accumulate until they are submitted.
if 'feedback_table' not in st.session_state:
    st.session_state.feedback_table = wandb.Table(columns=[
        'image', 'filetype', 'predicted_label', 'predicted_label_id',
        'correct_label', 'correct_label_id',
    ])
feedback_table = st.session_state.feedback_table

# One W&B run per session, reused across reruns until feedback is submitted.
if 'wandb_run' not in st.session_state:
    st.session_state.wandb_run = wandb.init(project='hydra-classifier', name='feedback-loop')
@st.cache_data
def _predict_label_id(image_key, _image):
    """Run LayoutLMv3 on ``_image`` and return the argmax class id as an ``int``.

    ``st.cache_data`` hashes only ``image_key``; the leading underscore tells it
    to skip hashing ``_image``. The caller must therefore pass a key that
    uniquely identifies the image's pixels.
    """
    import torch

    print('Encoding image')
    encoding = processor(
        # RGBA / palette / grayscale uploads (e.g. PNGs) are normalised to the
        # 3-channel input the vision backbone expects.
        _image.convert("RGB"),
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    print('Predicting image')
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**encoding)
    return outputs.logits.argmax(-1)[0].item()


def classify_image(_image):
    """Return the predicted class id (an index into ``id2label``) for a PIL image.

    Results are cached per distinct image. The key combines the image's mode,
    size and a SHA-256 digest of its pixel bytes, so different pages and
    uploads no longer collide on a single cache entry.
    """
    import hashlib

    digest = hashlib.sha256(_image.tobytes()).hexdigest()
    return _predict_label_id(f"{_image.mode}-{_image.size}-{digest}", _image)
if uploaded_file:
    # A PDF becomes one PIL image per page; an image upload is a single page.
    if uploaded_file.type == "application/pdf":
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        images = [Image.open(uploaded_file)]

    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
        prediction = classify_image(image)
        st.write(f"Prediction: {id2label[prediction]}")

        # Widget keys include the page index so each page gets its own widgets.
        feedback = st.radio(
            "Is the classification correct?", ("Yes", "No"),
            key=f'prediction-{i}'
        )
        if feedback == "No":
            correct_label = st.selectbox(
                "Please select the correct label:", labels,
                key=f'selectbox-{i}'
            )
            print(f'Correct label for image {i}: {correct_label}')
            # A button click triggers a rerun; st.button returns True only on that
            # rerun, which is when the corrected row is recorded.
            if st.button(f"Add feedback for Image {i}", key=f'add-{i}'):
                feedback_table.add_data(
                    wandb.Image(image),
                    uploaded_file.type,
                    id2label[prediction],
                    prediction,
                    correct_label,
                    label2id[correct_label],
                )

    # Exactly one submit button: its key is fixed, so creating it inside the
    # per-page loop would register duplicate widget keys for multi-page PDFs.
    if st.button("Submit all feedback", key='submit'):
        run = st.session_state.wandb_run
        run.log({'feedback_table': feedback_table})
        run.finish()
        # A finished run cannot log again. Drop it, and any session-stored
        # feedback table, so the next rerun starts a fresh run and table.
        st.session_state.pop('wandb_run', None)
        st.session_state.pop('feedback_table', None)
        st.success("Feedback submitted!")
|