File size: 1,700 Bytes
dc3e500
 
 
 
 
 
953e037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Make sure detectron2 is importable; when it is missing, try to install it
# from the upstream GitHub repository before the rest of the script runs.
# NOTE(review): presumably required by the LayoutLMv2 classes imported below.
try:
    import detectron2
except ImportError:
    # Only a missing package should trigger the install. A bare ``except``
    # would also swallow KeyboardInterrupt/SystemExit and unrelated failures.
    import importlib
    import subprocess
    import sys

    # Run pip through the current interpreter so the package is installed into
    # the environment this script is actually running in, and pass the command
    # as an argument list so no shell is involved. ``check=False`` keeps the
    # original best-effort behaviour: a failed install does not abort here.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ],
        check=False,
    )
    # Let the import system notice the package that was just installed.
    importlib.invalidate_caches()

import streamlit as st
from PIL import Image
import torch
from transformers import LayoutLMv2ForSequenceClassification, LayoutLMv2Processor

# Load the fine-tuned classifier and its matching processor from the same
# checkpoint identifier.
# NOTE(review): these loads are plain module-level statements, so they run
# every time the script is executed — confirm whether caching is wanted.
_CHECKPOINT = "Tornaid/LayoutLMv2_D3_Classifier"
model_ft = LayoutLMv2ForSequenceClassification.from_pretrained(_CHECKPOINT)
processor_ft = LayoutLMv2Processor.from_pretrained(_CHECKPOINT)

# Class names in class-index order: the position of a name in this tuple is
# the integer id used for it in both lookup tables below.
_LABELS = (
    'budget',
    'form',
    'file_folder',
    'invoice',
    'email',
    'handwritten',
    'id_pieces',
    'advertisement',
    'carte postale',
    'scientific_publication',
    'news_article',
    'scientific_report',
    'resume',
    'letter',
    'presentation',
    'questionnaire',
    'memo',
    'paye',
    'specification',
)

# Label name -> class index.
label2id = {name: index for index, name in enumerate(_LABELS)}

# Class index -> label name (inverse of ``label2id``).
id2label = dict(enumerate(_LABELS))

def predict_image_classification(image):
    """Classify a document image and return the predicted label name.

    Args:
        image: Image passed straight to ``processor_ft``; the caller in this
            file supplies an RGB ``PIL.Image``.

    Returns:
        The ``id2label`` entry for the index of the highest logit.

    Raises:
        KeyError: If the predicted index has no entry in ``id2label``.
    """
    inputs = processor_ft(image, return_tensors="pt", truncation=True, max_length=512)
    # Inference only: disable gradient tracking so the forward pass does not
    # build an autograd graph. The predicted index is unaffected.
    with torch.no_grad():
        outputs = model_ft(**inputs)
    # ``.item()` needs a single-element tensor, i.e. one image per call.
    prediction_index = outputs.logits.argmax(-1).item()
    return id2label[prediction_index]

# Streamlit user interface: upload an image, preview it, classify on demand.
st.title("Classification de documents")

uploaded_file = st.file_uploader(
    "Choisissez un fichier image",
    type=["jpg", "jpeg", "png"],
)

if uploaded_file is not None:
    # Convert to 3-channel RGB before preview and prediction.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Image chargée", use_column_width=True)

    # The button is only created once a file has been uploaded; the model is
    # only called when the button reports a click.
    classify_clicked = st.button("Classer")
    if classify_clicked:
        label_pred = predict_image_classification(image)
        st.write(f"Prédiction: {label_pred}")