"""Streamlit app: fetch a web page, extract its text, and rank topics with a TFLite model."""

import numpy as np
import pandas as pd
import requests
import streamlit as st
import trafilatura
from tensorflow.lite.python.interpreter import Interpreter

# File paths
MODEL_PATH = "./model.tflite"
VOCAB_PATH = "./vocab.txt"
LABELS_PATH = "./taxonomy_v2.csv"


@st.cache_resource
def load_vocab():
    """Read the vocabulary file and return its tokens as a list.

    Each line of ``VOCAB_PATH`` is one token (surrounding whitespace is
    stripped); the position of a token in the returned list is its ID.
    """
    # Decode explicitly as UTF-8 so non-ASCII tokens do not depend on the
    # platform's default encoding.
    with open(VOCAB_PATH, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


@st.cache_resource
def load_labels():
    """Load the taxonomy CSV and return a ``{ID: Topic}`` dictionary.

    The CSV at ``LABELS_PATH`` must provide ``ID`` and ``Topic`` columns;
    ``ID`` is cast to integer before being used as the dictionary key.
    """
    taxonomy = pd.read_csv(LABELS_PATH)
    taxonomy["ID"] = taxonomy["ID"].astype(int)
    return taxonomy.set_index("ID")["Topic"].to_dict()


@st.cache_resource
def load_model():
    """Create the TFLite interpreter for ``MODEL_PATH`` and allocate tensors.

    Returns:
        A tuple ``(interpreter, input_details, output_details)``.

    Raises:
        Exception: whatever the interpreter raised while loading; the error
            is shown in the UI first and then re-raised unchanged.
    """
    try:
        interpreter = Interpreter(model_path=MODEL_PATH)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
    except Exception as e:
        st.error(f"Failed to load the model: {e}")
        raise
    return interpreter, input_details, output_details


def preprocess_text(text, vocab, max_length=128):
    """Convert raw text into fixed-length model inputs.

    The text is split on whitespace, truncated to ``max_length`` words, and
    each word is mapped to its position in ``vocab``; words not present in
    the vocabulary map to the ID of ``"[UNK]"``. Unused positions are padded
    with ID 0.

    Args:
        text: The text to encode.
        vocab: Sequence of tokens; a token's index is its ID.
        max_length: Number of positions in each returned array.

    Returns:
        A tuple ``(token_ids, attention_mask, token_type_ids)`` of ``int32``
        arrays, each of shape ``(1, max_length)``. The mask is 1 for real
        words and 0 for padding; the token type IDs are all zeros.

    Raises:
        ValueError: if a word is missing from ``vocab`` and ``vocab`` has no
            ``"[UNK]"`` entry.
    """
    # NOTE(review): tokenization is a plain whitespace split with no
    # lower-casing or sub-word splitting -- confirm this matches how the
    # model's vocabulary was built.

    # Build the token -> ID lookup once instead of scanning the vocabulary
    # list for every word. ``setdefault`` keeps the first occurrence of a
    # duplicated token, matching ``list.index``.
    token_to_id = {}
    for index, token in enumerate(vocab):
        token_to_id.setdefault(token, index)
    unk_id = token_to_id.get("[UNK]")

    words = text.split()[:max_length]

    token_ids = np.zeros(max_length, dtype=np.int32)  # 0 doubles as the pad ID
    for position, word in enumerate(words):
        token_id = token_to_id.get(word, unk_id)
        if token_id is None:
            # Only reached for an out-of-vocabulary word when "[UNK]" is absent.
            raise ValueError("'[UNK]' is not in the vocabulary")
        token_ids[position] = token_id

    attention_mask = np.zeros(max_length, dtype=np.int32)
    attention_mask[: len(words)] = 1

    token_type_ids = np.zeros(max_length, dtype=np.int32)

    # Add the leading batch dimension expected by the caller.
    return (
        token_ids[np.newaxis, :],
        attention_mask[np.newaxis, :],
        token_type_ids[np.newaxis, :],
    )


def classify_text(interpreter, input_details, output_details,
                  input_word_ids, input_mask, input_type_ids):
    """Run one inference and return the scores for the single input row.

    Args:
        interpreter: An allocated TFLite interpreter.
        input_details: The interpreter's input tensor descriptions.
        output_details: The interpreter's output tensor descriptions.
        input_word_ids: Token ID array.
        input_mask: Attention mask array.
        input_type_ids: Token type ID array.

    Returns:
        The first row of the first output tensor.
    """
    # NOTE(review): this assumes the model's inputs are ordered
    # (word ids, mask, type ids) -- verify against the names in
    # ``input_details`` for the shipped model.
    interpreter.set_tensor(input_details[0]["index"], input_word_ids)
    interpreter.set_tensor(input_details[1]["index"], input_mask)
    interpreter.set_tensor(input_details[2]["index"], input_type_ids)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]["index"])
    return output[0]


def fetch_url_content(url):
    """Download ``url`` and return the response body as text.

    Returns:
        The decoded body when the server answers with status 200; ``None``
        on any other status code or on a request error. In both failure
        cases an error message is shown in the UI.
    """
    # No explicit "Accept-Encoding": requests/urllib3 advertise only the
    # encodings they can actually decode. Hard-coding "br" makes servers
    # reply with Brotli even when no Brotli decoder is installed, leaving
    # ``response.text`` undecoded.
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36",
        "Accept-Language": "en-US,en;q=0.9",
    }
    try:
        response = requests.get(url, headers=headers, cookies={}, timeout=10)
    except requests.RequestException as e:
        st.error(f"Error fetching content: {e}")
        return None

    if response.status_code == 200:
        return response.text

    st.error(f"Failed to fetch content. Status code: {response.status_code}")
    return None


def main():
    """Render the page: read a URL, extract its text, and show the top 5 topics."""
    st.title("Topic Classification from URL")

    url = st.text_input("Enter a URL:", "")
    if not url:
        return

    st.write("Extracting content from the URL...")
    raw_content = fetch_url_content(url)
    if not raw_content:
        st.error("Failed to fetch the URL.")
        return

    content = trafilatura.extract(raw_content)
    if not content:
        st.error("Unable to extract content from the fetched HTML.")
        return

    st.write("Content extracted successfully!")
    st.write(content[:500])  # Display a snippet of the content

    # Load resources (cached across reruns by st.cache_resource)
    vocab = load_vocab()
    labels_dict = load_labels()
    interpreter, input_details, output_details = load_model()

    # Preprocess content and classify
    input_word_ids, input_mask, input_type_ids = preprocess_text(content, vocab)
    predictions = classify_text(
        interpreter, input_details, output_details,
        input_word_ids, input_mask, input_type_ids,
    )

    # Display classification
    st.write("Topic Classification:")
    top_indices = np.argsort(predictions)[::-1][:5]  # Top 5 topics
    for idx in top_indices:
        idx = int(idx)  # plain int for the label lookup and display
        topic = labels_dict.get(idx, "Unknown Topic")
        st.write(f"ID: {idx} - Topic: {topic} - Score: {predictions[idx]:.4f}")


# Called unconditionally, as the original top-level statements were, so the
# page is rendered whenever this file is executed.
main()