# cdelaunay's picture
# Update app.py
# dc3e500 verified
# Make sure detectron2 is importable before the model code below runs.
# If it is missing, try to install it from source into the interpreter
# that is running this script.
try:
    import detectron2
except ImportError:
    import subprocess
    import sys

    # Best effort, as before: a failed install is not raised here (check=False).
    # Invoking pip through sys.executable targets the current environment and
    # avoids going through a shell.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ],
        check=False,
    )
import streamlit as st
from PIL import Image
import torch
from transformers import LayoutLMv2ForSequenceClassification, LayoutLMv2Processor
# Load the model and its processor.
_MODEL_ID = "Tornaid/LayoutLMv2_D3_Classifier"


@st.cache_resource
def _load_model_and_processor():
    """Load the classifier and its processor from ``_MODEL_ID``.

    Streamlit re-executes this script on every user interaction; caching the
    loaded objects as a resource avoids repeating both ``from_pretrained``
    calls on each rerun.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = LayoutLMv2ForSequenceClassification.from_pretrained(_MODEL_ID)
    processor = LayoutLMv2Processor.from_pretrained(_MODEL_ID)
    return model, processor


model_ft, processor_ft = _load_model_and_processor()
# Class names in index order: a label's position in this tuple is its id.
_LABELS = (
    'budget',
    'form',
    'file_folder',
    'invoice',
    'email',
    'handwritten',
    'id_pieces',
    'advertisement',
    'carte postale',
    'scientific_publication',
    'news_article',
    'scientific_report',
    'resume',
    'letter',
    'presentation',
    'questionnaire',
    'memo',
    'paye',
    'specification',
)

# Forward (name -> index) and reverse (index -> name) lookup tables.
label2id = {name: index for index, name in enumerate(_LABELS)}
id2label = dict(enumerate(_LABELS))
def predict_image_classification(image) -> str:
    """Classify a document image and return the predicted label name.

    The image is encoded with ``processor_ft`` (PyTorch tensors, truncation
    enabled, ``max_length=512``), passed through ``model_ft``, and the index
    of the highest logit is looked up in ``id2label``.

    Args:
        image: The document image to classify. The caller in this file passes
            a PIL image converted to RGB.

    Returns:
        The entry of ``id2label`` for the highest-scoring class index.

    Raises:
        KeyError: If the predicted index is not a key of ``id2label``.
    """
    inputs = processor_ft(image, return_tensors="pt", truncation=True, max_length=512)
    # Inference only: disable autograd so no gradient graph is built.
    with torch.no_grad():
        outputs = model_ft(**inputs)
    # .item() expects a single prediction, i.e. one image per call.
    prediction_index = outputs.logits.argmax(-1).item()
    return id2label[prediction_index]
# Streamlit interface: upload an image, preview it, classify it on demand.
st.title("Classification de documents")

uploaded_file = st.file_uploader(
    "Choisissez un fichier image",
    type=["jpg", "jpeg", "png"],
)

if uploaded_file is not None:
    # Decode the upload and convert it to RGB before display and prediction.
    image = Image.open(uploaded_file)
    image = image.convert("RGB")
    st.image(image, caption="Image chargée", use_column_width=True)

    # Prediction runs only on the rerun triggered by the button click.
    classify_clicked = st.button("Classer")
    if classify_clicked:
        label_pred = predict_image_classification(image)
        st.write(f"Prédiction: {label_pred}")