Spaces:
Runtime error
Runtime error
# Ensure detectron2 is importable; when it is missing, install it from its
# GitHub repository on first run. NOTE(review): presumably required by the
# LayoutLMv2 model loaded below (transformers lists detectron2 as a LayoutLMv2
# dependency) — confirm against the deployed transformers version.
try:
    import detectron2  # noqa: F401  (imported only to check availability)
except ImportError:
    import importlib
    import subprocess
    import sys

    # Install into the interpreter running this app (not whichever `pip` is on
    # PATH), without going through a shell, and fail loudly if the build fails
    # instead of continuing and crashing later with a confusing error.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ],
        check=True,
    )
    # Make the freshly installed package visible to this running process,
    # then import it so a broken install surfaces here.
    importlib.invalidate_caches()
    import detectron2  # noqa: F401
import streamlit as st | |
from PIL import Image | |
import torch | |
from transformers import LayoutLMv2ForSequenceClassification, LayoutLMv2Processor | |
# Load the fine-tuned model and its processor.
# Streamlit re-executes this whole script on every user interaction, so the
# loader is cached with st.cache_resource to load the checkpoint only once per
# server process instead of on every upload / button click.
MODEL_NAME = "Tornaid/LayoutLMv2_D3_Classifier"


@st.cache_resource
def _load_model_and_processor(model_name=MODEL_NAME):
    """Return ``(model, processor)`` for *model_name*, cached across reruns."""
    model = LayoutLMv2ForSequenceClassification.from_pretrained(model_name)
    # from_pretrained already returns the model in eval mode; kept explicit
    # because this app only ever runs inference.
    model.eval()
    processor = LayoutLMv2Processor.from_pretrained(model_name)
    return model, processor


model_ft, processor_ft = _load_model_and_processor()
# Mapping from document-type label to the classifier's output index.
# NOTE(review): these ids must match the label order used when the checkpoint
# was fine-tuned — confirm against the model's training setup. Keys are kept
# verbatim (including the French "carte postale" and "paye") because the
# predicted key is what the app displays to the user.
label2id = {
    'budget': 0, 'form': 1, 'file_folder': 2, 'invoice': 3, 'email': 4,
    'handwritten': 5, 'id_pieces': 6, 'advertisement': 7, 'carte postale': 8,
    'scientific_publication': 9, 'news_article': 10, 'scientific_report': 11,
    'resume': 12, 'letter': 13, 'presentation': 14, 'questionnaire': 15,
    'memo': 16, 'paye': 17, 'specification': 18,
}
# Reverse lookup: predicted class index -> label name.
id2label = {idx: label for label, idx in label2id.items()}
def predict_image_classification(image):
    """Classify a document image and return its predicted label.

    Args:
        image: The document image, passed unchanged to ``processor_ft``
            (the Streamlit UI supplies an RGB ``PIL.Image``).

    Returns:
        The label name from ``id2label`` for the highest-scoring class.
    """
    # Only the image is passed (no words/boxes), so the processor presumably
    # runs its own OCR step — confirm the processor's apply_ocr setting.
    # The token sequence is truncated to 512.
    inputs = processor_ft(image, return_tensors="pt", truncation=True, max_length=512)
    # Inference only: disable autograd tracking to save memory and time.
    with torch.no_grad():
        outputs = model_ft(**inputs)
    # A single image gives a batch of one, so .item() yields a plain int.
    prediction_index = outputs.logits.argmax(-1).item()
    return id2label[prediction_index]
# Streamlit interface: upload an image, preview it, and classify on demand.
st.title("Classification de documents")
uploaded_file = st.file_uploader("Choisissez un fichier image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # convert() forces the image to be decoded; a corrupt or mislabelled
        # upload raises OSError (PIL's UnidentifiedImageError subclasses it),
        # which would otherwise crash the page.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        st.error("Impossible de lire ce fichier comme une image.")
    else:
        st.image(image, caption="Image chargée", use_column_width=True)
        if st.button("Classer"):
            # Inference can take a while; show progress instead of a frozen UI.
            with st.spinner("Classification en cours..."):
                label_pred = predict_image_classification(image)
            st.write(f"Prédiction: {label_pred}")