"""Streamlit app that classifies a textual segment with a fine-tuned CamemBERT.

The page shows a text area and a "Predict" button; on click the text is run
through the ``herelles/camembert-base-lupan`` sequence-classification
checkpoint and the predicted class label is displayed.
"""

import numpy as np
import streamlit as st
import torch
from transformers import CamembertForSequenceClassification, CamembertTokenizer

# Longest token sequence fed to the model. camembert-base cannot embed
# positions beyond this, so longer inputs are truncated instead of crashing.
MAX_TOKENS = 512

# Class id (argmax over the logits) -> label shown to the user.
LABELS = {
    0: 'Not pertinent',
    1: 'Pertinent (Soft)',
    2: 'Pertinent (Strict, Non-verifiable)',
    3: 'Pertinent (Strict, Verifiable)',
}


@st.cache_resource
def load_tokenizer():
    """Return the CamemBERT tokenizer (cached by Streamlit across reruns)."""
    return CamembertTokenizer.from_pretrained("camembert-base")


@st.cache_resource
def load_model():
    """Return the fine-tuned classification model (cached across reruns)."""
    return CamembertForSequenceClassification.from_pretrained("herelles/camembert-base-lupan")


# Define tokenizer:
tokenizer = load_tokenizer()

# Load model:
model = load_model()
model.to('cpu')


def prediction(segment_text):
    """Classify one text segment and return the predicted class id.

    Args:
        segment_text: The text to classify.

    Returns:
        The index of the highest logit as a plain ``int``.
    """
    # Truncate so that over-long segments do not overflow the model's
    # position embeddings during the forward pass.
    encoding = tokenizer(
        segment_text,
        padding="longest",
        truncation=True,
        max_length=MAX_TOKENS,
        return_tensors="pt",
    )

    # Forward pass without gradient tracking: inference only.
    with torch.no_grad():
        output = model(
            encoding['input_ids'].to('cpu'),
            attention_mask=encoding['attention_mask'].to('cpu'),
        )

    return int(np.argmax(output.logits.cpu().numpy()))


def main():
    """Render the page and show the predicted label when the button is clicked."""
    st.header('Textual segments Hérelles prediction tool', divider='rainbow')

    segment_text = st.text_area(
        "Text to classify:",
        "Article 1 : Occupations ou utilisations du sol interdites\n\n"
        "1) Dans l’ensemble de la zone sont interdits :\n\n"
        "Les constructions destinées à l’habitation ne dépendant pas d’une exploitation agricole autres\n"
        "que celles visées à l’article 2 paragraphe 1).",
        height=170,
    )

    if st.button('Predict'):
        # Nothing meaningful to classify in an empty / whitespace-only box.
        if not segment_text or not segment_text.strip():
            st.warning("Please enter some text to classify.")
            return

        pred_id = prediction(segment_text)
        # Fall back to an explicit message rather than failing on an
        # unbound label if the model ever yields an unexpected class id.
        pred_label = LABELS.get(pred_id, f"Unknown class ({pred_id})")

        st.write("Predicted Class: ", pred_label)


if __name__ == "__main__":
    main()