# Hugging Face Spaces page header (Space status: Sleeping)
import numpy as np
import pandas as pd
import requests
import streamlit as st
import trafilatura
from tensorflow.lite.python.interpreter import Interpreter
# Locations of the model artifacts. The "./" prefix makes them resolve
# against the current working directory of the running process.
MODEL_PATH = "./model.tflite"  # TFLite classifier loaded by the Interpreter
VOCAB_PATH = "./vocab.txt"  # token vocabulary, one token per line
LABELS_PATH = "./taxonomy_v2.csv"  # topic taxonomy with "ID" and "Topic" columns
def load_vocab(path=None):
    """Read the tokenizer vocabulary, one token per line.

    Args:
        path: Vocabulary file to read. Defaults to ``VOCAB_PATH`` when None.

    Returns:
        List of tokens in file order, with surrounding whitespace stripped.
        A token's position in this list is what ``preprocess_text`` uses as
        its id.
    """
    if path is None:
        path = VOCAB_PATH
    # Decode explicitly as UTF-8: without an encoding argument, open() uses
    # the platform's locale encoding, which can fail on (or silently
    # mis-decode) non-ASCII tokens that vocab files commonly contain.
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]
def load_labels(path=None):
    """Load the topic taxonomy as a mapping from integer topic ID to topic name.

    Args:
        path: CSV file with at least an ``ID`` and a ``Topic`` column.
            Defaults to ``LABELS_PATH`` when None.

    Returns:
        dict mapping each ID (cast to int) to its ``Topic`` value. If an ID
        appears more than once, the last row wins (``Series.to_dict``
        semantics).
    """
    if path is None:
        path = LABELS_PATH
    taxonomy = pd.read_csv(path)
    # Cast IDs to int so lookups by integer model-output index match the keys.
    taxonomy["ID"] = taxonomy["ID"].astype(int)
    return taxonomy.set_index("ID")["Topic"].to_dict()
def load_model():
    """Build a TFLite interpreter for ``MODEL_PATH`` and allocate its tensors.

    Returns:
        Tuple ``(interpreter, input_details, output_details)``, where the two
        detail lists come from ``get_input_details()`` and
        ``get_output_details()``.

    Raises:
        Re-raises any exception from creating or preparing the interpreter,
        after reporting it on the page with ``st.error``.
    """
    try:
        tflite_interpreter = Interpreter(model_path=MODEL_PATH)
        tflite_interpreter.allocate_tensors()
        inputs = tflite_interpreter.get_input_details()
        outputs = tflite_interpreter.get_output_details()
    except Exception as exc:
        # Show the failure in the UI, then propagate it so the script run
        # stops rather than continuing without a model.
        st.error(f"Failed to load the model: {exc}")
        raise
    return tflite_interpreter, inputs, outputs
def preprocess_text(text, vocab, max_length=128):
    """Encode raw text into the three int32 model input arrays.

    Tokenization is a plain whitespace split. The text is truncated to
    ``max_length`` words, and each word is looked up verbatim in ``vocab``.
    A word's id is the index of its first occurrence. Words not found map to
    the id of ``"[UNK]"``.

    Args:
        text: Text to encode.
        vocab: Sequence of tokens; a token's position is its id.
        max_length: Fixed sequence length to truncate/pad to.

    Returns:
        ``(input_word_ids, input_mask, input_type_ids)``. Each has shape
        ``(1, max_length)`` and dtype int32:
        - word ids are zero-padded after the real words;
        - the mask is 1 for real words and 0 for padding;
        - type ids are all 0.

    Raises:
        ValueError: if a word is not in ``vocab`` and ``vocab`` has no
            ``"[UNK]"`` entry.
    """
    # NOTE(review): no lowercasing, punctuation splitting, sub-word splitting
    # or special-token insertion happens here. If the model was trained with
    # a BERT-style tokenizer, many words will fall back to [UNK] — confirm
    # against the model's training pipeline.
    words = text.split()[:max_length]

    # One pass over the vocabulary instead of scanning the whole list for
    # every word. setdefault keeps the FIRST position of duplicate tokens,
    # matching list.index().
    token_to_id = {}
    for position, token in enumerate(vocab):
        token_to_id.setdefault(token, position)
    unk_id = token_to_id.get("[UNK]")

    ids = []
    for word in words:
        token_id = token_to_id.get(word, unk_id)
        if token_id is None:
            raise ValueError("'[UNK]' is not in vocab")
        ids.append(token_id)

    n_words = len(ids)
    token_ids = np.zeros(max_length, dtype=np.int32)
    token_ids[:n_words] = ids  # remaining positions stay 0 (padding)
    attention_mask = np.zeros(max_length, dtype=np.int32)
    attention_mask[:n_words] = 1
    token_type_ids = np.zeros(max_length, dtype=np.int32)
    # Add a leading batch axis of size 1.
    return token_ids[np.newaxis, :], attention_mask[np.newaxis, :], token_type_ids[np.newaxis, :]
def classify_text(interpreter, input_details, output_details, input_word_ids, input_mask, input_type_ids):
    """Run the TFLite model on one encoded example and return its scores.

    Args:
        interpreter: Interpreter whose tensors have already been allocated.
        input_details: Result of ``interpreter.get_input_details()``.
        output_details: Result of ``interpreter.get_output_details()``.
        input_word_ids: Token-id array for the example.
        input_mask: Attention-mask array for the example.
        input_type_ids: Type-id array for the example.

    Returns:
        ``output[0]`` of the first output tensor. Presumably these are the
        scores for the first (and only) batch item.
    """
    arrays = (input_word_ids, input_mask, input_type_ids)
    for detail, array in zip(_order_bert_inputs(input_details), arrays):
        # set_tensor rejects arrays whose dtype differs from the model input,
        # so cast to the declared dtype (a no-op when it already matches).
        interpreter.set_tensor(detail["index"], np.asarray(array, dtype=detail.get("dtype")))
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]["index"])
    return output[0]


def _order_bert_inputs(input_details):
    """Return the input details for (word ids, mask, type ids), matched by name when possible."""
    # get_input_details() is not guaranteed to list inputs in the
    # (word ids, mask, type ids) order used above. Converted BERT-style
    # models can list them alphabetically (input_mask, input_type_ids,
    # input_word_ids). Match on tensor names; fall back to the positional
    # order only when the names are not distinctive.
    matched = []
    for keyword in ("word", "mask", "type"):
        hits = [d for d in input_details if keyword in str(d.get("name", "")).lower()]
        if len(hits) != 1:
            break
        matched.append(hits[0])
    if len(matched) == 3 and len({d["index"] for d in matched}) == 3:
        return matched
    # Positional fallback; raises IndexError for fewer than three inputs.
    return [input_details[0], input_details[1], input_details[2]]
def fetch_url_content(url):
    """Download ``url`` and return the response body as text.

    Failures are reported on the page with ``st.error`` instead of being
    raised.

    Args:
        url: Address entered by the user (untrusted input).

    Returns:
        The decoded response body when the server answers HTTP 200,
        otherwise ``None``.
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36",
        "Accept-Language": "en-US,en;q=0.9",
        # Accept-Encoding is deliberately left to requests, which only
        # advertises encodings it can decode. Forcing "br" could make servers
        # send Brotli bodies that requests cannot decompress unless the
        # optional brotli package is installed.
    }
    # NOTE(review): the URL comes straight from the user and is fetched
    # server-side, so it can target internal hosts (SSRF). Consider an
    # allow-list or a private-address check if this app is public.
    try:
        response = requests.get(url, headers=headers, timeout=10)
    except requests.RequestException as e:
        # Covers connection errors, timeouts and malformed/unsupported URLs.
        st.error(f"Error fetching content: {e}")
        return None
    if response.status_code != 200:
        st.error(f"Failed to fetch content. Status code: {response.status_code}")
        return None
    return response.text
# Streamlit page: take a URL, extract its main text with trafilatura, classify
# it with the TFLite model and list the five highest-scoring topics.
st.title("Topic Classification from URL")
url = st.text_input("Enter a URL:", "")

if url:
    st.write("Extracting content from the URL...")
    raw_content = fetch_url_content(url)
    if not raw_content:
        st.error("Failed to fetch the URL.")
    else:
        content = trafilatura.extract(raw_content)
        if not content:
            st.error("Unable to extract content from the fetched HTML.")
        else:
            st.write("Content extracted successfully!")
            # Preview only the first 500 characters of the extracted text.
            st.write(content[:500])

            # NOTE(review): vocab, labels and the interpreter are reloaded on
            # every Streamlit rerun. Consider caching them, but only once it
            # is confirmed that a shared interpreter is safe across sessions.
            vocab = load_vocab()
            labels_dict = load_labels()
            interpreter, input_details, output_details = load_model()

            input_word_ids, input_mask, input_type_ids = preprocess_text(content, vocab)
            predictions = classify_text(
                interpreter,
                input_details,
                output_details,
                input_word_ids,
                input_mask,
                input_type_ids,
            )

            st.write("Topic Classification:")
            # Indices of the five largest scores, highest first.
            # NOTE(review): this assumes model output index == taxonomy ID;
            # confirm the model's label order against the taxonomy CSV.
            sorted_indices = np.argsort(predictions)[-5:][::-1]
            for idx in sorted_indices:
                topic = labels_dict.get(idx, "Unknown Topic")
                st.write(f"ID: {idx} - Topic: {topic} - Score: {predictions[idx]:.4f}")