Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from utils.util_classifier import TextClassificationPipeline | |
| import time | |
| import requests | |
| import io | |
| import pdfplumber | |
| from urllib.parse import urlparse | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
def validate_url(url):
    """Return True if ``url`` has both a scheme and a network location.

    This is a shape check only (e.g. ``https://host/file.pdf`` passes,
    ``host/file.pdf`` does not); it does not verify that the host exists or
    that the scheme is one the downloader can fetch.

    Args:
        url: The candidate URL entered by the user.

    Returns:
        True when both ``scheme`` and ``netloc`` are non-empty, otherwise
        False. Inputs that ``urlparse`` rejects (e.g. a malformed IPv6 host
        or a non-string object) also yield False instead of raising.
    """
    try:
        parsed = urlparse(url)
    except (ValueError, TypeError, AttributeError):
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # and genuine programming errors are no longer silently swallowed.
        return False
    return bool(parsed.scheme and parsed.netloc)
def download_pdf(url, timeout=30):
    """Download a PDF from ``url`` and return it as an in-memory binary stream.

    Args:
        url: Address of the document to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout. Without
            one, a server that never responds would block the app forever.

    Returns:
        An ``io.BytesIO`` holding the response body, or ``None`` when the
        request fails, returns an HTTP error status, or the body does not
        look like a PDF. In every failure case the error is shown via
        ``st.error``.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
        'Accept': 'application/pdf,*/*',
        # NOTE(review): Referer is pinned to a single site; presumably needed
        # by one specific host — confirm before reusing for arbitrary URLs.
        'Referer': 'https://www.inter-lux.com/'
    }
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
        response.raise_for_status()

        content_type = response.headers.get('content-type', '')
        body = response.content
        # Some servers label PDFs generically (e.g. application/octet-stream),
        # so also accept a body carrying the PDF header signature near its start.
        looks_like_pdf = (
            'application/pdf' in content_type.lower()
            or b'%PDF-' in body[:1024]
        )
        if not looks_like_pdf:
            raise ValueError(f"URL does not point to a PDF file. Content-Type: {content_type}")
        return io.BytesIO(body)
    except Exception as e:
        # UI boundary: report any failure to the user rather than crash the page.
        st.error(f"Download error: {str(e)}")
        return None
def extract_text(pdf_file):
    """Extract the text of every page of ``pdf_file`` using pdfplumber.

    Args:
        pdf_file: A seekable binary file-like object, such as the
            ``io.BytesIO`` returned by ``download_pdf`` or a Streamlit upload.

    Returns:
        The non-empty page texts joined by newlines and stripped of
        surrounding whitespace, or ``None`` if extraction raised or produced
        no text. In the failure case the error is shown via ``st.error``.
    """
    try:
        # Rewind: the stream may already have been read.
        pdf_file.seek(0)
        page_texts = []
        with pdfplumber.open(pdf_file) as pdf:
            for page in pdf.pages:
                extracted = page.extract_text()
                # Pages without a text layer return nothing; skip them.
                if extracted:
                    page_texts.append(extracted)
        # Join once instead of repeated string concatenation in the loop.
        text = "\n".join(page_texts).strip()
        if not text:
            raise ValueError("No text could be extracted from the PDF")
        return text
    except Exception as e:
        # UI boundary: report any failure to the user rather than crash the page.
        st.error(f"Text extraction error: {str(e)}")
        return None
@st.cache_resource
def _load_classifier(method):
    """Build the classification pipeline once and reuse it across reruns.

    Streamlit re-executes the script on every widget interaction, so building
    the pipeline directly in ``main`` reloaded the model each time.
    NOTE(review): the cached instance is shared by all sessions; this assumes
    ``TextClassificationPipeline.predict`` keeps no per-call mutable state.
    """
    return TextClassificationPipeline(method=method)


def _create_gauge_chart(confidence):
    """Return a Plotly gauge showing ``confidence`` on a 0-100 % scale.

    Args:
        confidence: Score of the predicted label; presumably in [0, 1], since
            it is multiplied by 100 for the gauge axis — confirm with the model.
    """
    fig = go.Figure(go.Indicator(
        # No reference value is supplied, so the former "delta" part had
        # nothing meaningful to compare against; show gauge and number only.
        mode="gauge+number",
        value=confidence * 100,
        number={'suffix': '%', 'valueformat': '.1f'},
        domain={'x': [0, 1], 'y': [0, 1]},
        gauge={
            'axis': {'range': [None, 100], 'tickwidth': 1, 'tickcolor': "darkblue"},
            'bar': {'color': "darkblue"},
            'bgcolor': "white",
            'borderwidth': 2,
            'bordercolor': "gray",
            # Background bands: 0-50 red-ish, 50-75 orange-ish, 75-100 green-ish.
            'steps': [
                {'range': [0, 50], 'color': '#FF9999'},
                {'range': [50, 75], 'color': '#FFCC99'},
                {'range': [75, 100], 'color': '#99FF99'},
            ],
        },
        title={'text': "Confidence Score"},
    ))
    fig.update_layout(
        height=300,
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor='rgba(0,0,0,0)',
        font={'color': "darkblue", 'family': "Arial"},
    )
    return fig


def _create_probability_chart(probabilities):
    """Return a horizontal bar chart of per-label probabilities in percent.

    Args:
        probabilities: Mapping of label -> score; scores are multiplied by 100,
            so presumably in [0, 1].
    """
    labels = list(probabilities.keys())
    percents = [score * 100 for score in probabilities.values()]

    # Skip the two palest Blues shades and cycle through the rest. Indexing
    # Blues[i + 2] directly raised IndexError once labels outnumbered shades.
    shades = px.colors.sequential.Blues[2:]
    bar_colors = [shades[i % len(shades)] for i in range(len(labels))]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=labels,
        x=percents,
        orientation='h',
        marker=dict(
            color=bar_colors,
            line=dict(color='rgba(0,0,0,0.8)', width=2),
        ),
        text=[f'{pct:.1f}%' for pct in percents],
        textposition='auto',
    ))
    fig.update_layout(
        title=dict(
            text='Probability Distribution',
            y=0.95,
            x=0.5,
            xanchor='center',
            yanchor='top',
            font=dict(size=20, color='darkblue'),
        ),
        xaxis_title="Probability (%)",
        yaxis_title="Categories",
        height=400,
        margin=dict(l=20, r=20, t=70, b=20),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(family="Arial", size=14),
        showlegend=False,
    )
    fig.update_xaxes(
        range=[0, 100],
        gridcolor='rgba(0,0,0,0.1)',
        zerolinecolor='rgba(0,0,0,0.2)',
    )
    fig.update_yaxes(
        gridcolor='rgba(0,0,0,0.1)',
        zerolinecolor='rgba(0,0,0,0.2)',
    )
    return fig


def _display_results(result, text):
    """Render the prediction card, confidence gauge, probability chart and details.

    Args:
        result: One prediction dict from ``classifier.predict``. Reads
            'predicted_label', 'confidence' and 'probabilities' (label -> score),
            and optionally 'model_type' and 'text'.
        text: The extracted document text, used for the reported length when
            ``result`` carries no 'text' entry.
    """
    import html  # escape the model label before embedding it in raw HTML

    label = html.escape(str(result['predicted_label']))
    card_html = (
        "<div style='background-color: white; padding: 20px; border-radius: 10px; "
        "box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); text-align: center; margin-bottom: 20px;'>"
        "<h4 style='color: #1f77b4; margin-bottom: 10px;'>Predicted Category</h4>"
        "<p style='font-size: 24px; font-weight: bold; color: #2c3e50; margin: 0; "
        "padding: 10px; background-color: #f8f9fa; border-radius: 5px;'>"
        f"{label}</p>"
        "</div>"
    )

    col1, col2 = st.columns([1, 2])
    with col1:
        st.markdown(card_html, unsafe_allow_html=True)
        st.plotly_chart(_create_gauge_chart(result['confidence']), use_container_width=True)
    with col2:
        st.plotly_chart(_create_probability_chart(result['probabilities']), use_container_width=True)

    # NOTE(review): these keys were indexed directly before; .get() keeps the
    # page rendering if the classifier's result does not include them.
    model_type = str(result.get('model_type', 'unknown')).title()
    doc_length = len(result.get('text', text))
    with st.expander("π Classification Details"):
        st.markdown(
            f"- **Model Type**: {model_type}\n"
            f"- **Document Length**: {doc_length} characters"
        )


def _predict(classifier, text):
    """Classify ``text`` and return a single result dict, or None on an empty result.

    ``predict`` may return a list; only its first entry is used because a
    single document is classified per request.
    """
    result = classifier.predict(text, return_probability=True)
    if isinstance(result, list):
        if not result:
            st.error("The classifier returned no prediction.")
            return None
        result = result[0]
    return result


def _classify_from_url(classifier, url):
    """Download, extract, classify and display the PDF at ``url``.

    Returns early (only from this tab's handler) on any failed step; the
    failing helper has already reported the error in the UI.
    """
    if not validate_url(url):
        st.error("Please enter a valid URL")
        return

    with st.spinner("Downloading PDF..."):
        pdf_file = download_pdf(url)
    if pdf_file is None:
        return
    st.success("PDF downloaded successfully!")

    with st.spinner("Extracting text from PDF..."):
        text = extract_text(pdf_file)
    if not text or not text.strip():
        return
    st.success("Text extracted successfully!")
    with st.expander("View Extracted Text"):
        st.text(text[:500] + "..." if len(text) > 500 else text)

    with st.spinner("Classifying document..."):
        result = _predict(classifier, text)
    if result is None:
        return

    st.markdown("### π Classification Results")
    _display_results(result, text)


def _classify_from_upload(classifier, uploaded_file):
    """Extract, classify and display an uploaded PDF (same display as the URL tab)."""
    with st.spinner("Processing uploaded PDF..."):
        text = extract_text(uploaded_file)
        if text is None:
            return
        result = _predict(classifier, text)
    if result is None:
        return

    st.markdown("### π Classification Results")
    _display_results(result, text)


def main():
    """Render the Document Classifier page with a URL tab and a file-upload tab.

    Each tab delegates to its own handler so an early exit in one tab (e.g. an
    invalid URL) no longer prevents the other tab from being rendered.
    """
    st.title("π― Document Classifier")

    method = "bertbased"
    classifier = _load_classifier(method)

    tab1, tab2 = st.tabs(["π URL Input", "π File Upload"])

    with tab1:
        url = st.text_input("Enter PDF URL")
        process_btn = st.button("Classify Document", key="url_classify")
        if process_btn and url:
            _classify_from_url(classifier, url)

    with tab2:
        uploaded_file = st.file_uploader("Upload PDF file", type="pdf")
        process_btn = st.button("Classify Document", key="file_classify")
        if process_btn and uploaded_file:
            _classify_from_upload(classifier, uploaded_file)
# `streamlit run` executes this script as __main__; the guard keeps a plain
# import of the module (e.g. from tests) free of UI side effects.
if __name__ == "__main__":
    main()