# browsing-topics / app.py
# Committed by dejanseo: "Update app.py" (bf7efe3, verified), 4.17 kB
import streamlit as st
import trafilatura
import numpy as np
import pandas as pd
from tensorflow.lite.python.interpreter import Interpreter
import requests
# File paths, resolved relative to the current working directory.
MODEL_PATH: str = "./model.tflite"  # TensorFlow Lite model opened by load_model()
VOCAB_PATH: str = "./vocab.txt"  # token list read line by line by load_vocab()
LABELS_PATH: str = "./taxonomy_v2.csv"  # CSV with "ID" and "Topic" columns, read by load_labels()
@st.cache_resource
def load_vocab():
    """Read the tokenizer vocabulary from ``VOCAB_PATH``, one token per line.

    A token's position in the returned list is its integer ID; this is the
    ID that ``preprocess_text`` looks up for each word. The result is cached
    by ``st.cache_resource`` so the file is not re-read on every rerun.

    Returns:
        list[str]: Tokens in file order, with surrounding whitespace stripped.
    """
    # Pin UTF-8 so decoding does not depend on the platform's default locale
    # encoding (e.g. cp1252 on Windows), which would mangle or reject any
    # non-ASCII tokens in the vocabulary file.
    with open(VOCAB_PATH, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]
@st.cache_resource
def load_labels():
    """Build a ``{topic_id: topic_name}`` lookup from the taxonomy CSV.

    Reads ``LABELS_PATH``, which must provide ``ID`` and ``Topic`` columns,
    casts the IDs to ``int`` and indexes topic names by them. Cached by
    ``st.cache_resource`` so the CSV is parsed once and reused across reruns.

    Returns:
        dict: Topic names keyed by integer taxonomy ID.
    """
    return (
        pd.read_csv(LABELS_PATH)
        .astype({"ID": int})
        .set_index("ID")["Topic"]
        .to_dict()
    )
@st.cache_resource
def load_model():
    """Create a ready-to-run TensorFlow Lite interpreter for ``MODEL_PATH``.

    Tensors are allocated immediately so callers can go straight to
    ``set_tensor``/``invoke``. The result is cached by ``st.cache_resource``,
    so the model is loaded once and shared across reruns.

    Returns:
        tuple: ``(interpreter, input_details, output_details)``. The last two
        are the interpreter's input and output tensor descriptions.

    Raises:
        Exception: Whatever loading raised. The message is first shown in the
        app via ``st.error``, then the original exception is re-raised.
    """
    try:
        tflite = Interpreter(model_path=MODEL_PATH)
        tflite.allocate_tensors()
        return tflite, tflite.get_input_details(), tflite.get_output_details()
    except Exception as exc:
        # Surface the failure in the UI, then propagate it unchanged.
        st.error(f"Failed to load the model: {exc}")
        raise
def preprocess_text(text, vocab, max_length=128):
    """Encode ``text`` into the three int32 input arrays fed to the model.

    Tokenisation is plain whitespace splitting. Each word is looked up
    verbatim in ``vocab``, with no lower-casing or sub-word splitting. Words
    that are not found map to the ``"[UNK]"`` token.

    Args:
        text: Text to encode.
        vocab: Sequence of tokens whose positions are their IDs. If a token
            occurs more than once, its first position is used.
        max_length: Fixed sequence length. Longer input is truncated; shorter
            input is padded with ID 0.

    Returns:
        tuple: ``(input_word_ids, input_mask, input_type_ids)``, each an
        ``np.int32`` array of shape ``(1, max_length)``. The mask is 1 for
        real words and 0 for padding. Type IDs are all 0.

    Raises:
        ValueError: If a word is not in ``vocab`` and ``vocab`` has no
            ``"[UNK]"`` entry.
    """
    words = text.split()[:max_length]  # split and truncate

    # Build the token -> id table once. ``word in vocab`` / ``vocab.index``
    # would scan the whole list for every word, which is slow for
    # vocabularies of tens of thousands of tokens. ``setdefault`` keeps the
    # first occurrence, matching ``list.index`` semantics.
    token_to_id = {}
    for position, token in enumerate(vocab):
        token_to_id.setdefault(token, position)
    # "[UNK]" is only required if an unknown word actually appears.
    unk_id = token_to_id.get("[UNK]")

    token_ids = np.zeros(max_length, dtype=np.int32)  # 0 doubles as padding
    for position, word in enumerate(words):
        token_id = token_to_id.get(word, unk_id)
        if token_id is None:
            raise ValueError(
                f"{word!r} is not in the vocabulary and the vocabulary has no '[UNK]' token"
            )
        token_ids[position] = token_id

    attention_mask = np.zeros(max_length, dtype=np.int32)
    attention_mask[: len(words)] = 1  # attend to real words, ignore padding
    token_type_ids = np.zeros(max_length, dtype=np.int32)  # single segment

    # Add a leading batch dimension of size 1.
    return (
        token_ids[np.newaxis, :],
        attention_mask[np.newaxis, :],
        token_type_ids[np.newaxis, :],
    )
def classify_text(interpreter, input_details, output_details, input_word_ids, input_mask, input_type_ids):
    """Run one forward pass of the TFLite interpreter and return its scores.

    The three arrays are written to the interpreter's inputs by position:
    word IDs to input 0, mask to input 1, type IDs to input 2. The
    interpreter is then invoked.

    NOTE(review): this assumes ``input_details`` lists the model inputs in
    exactly that order. TODO confirm against the tensor names reported by
    ``get_input_details()`` for this model.

    Returns:
        The first row of the first output tensor, i.e. the scores for the
        single example in the batch.
    """
    feeds = (input_word_ids, input_mask, input_type_ids)
    for slot, array in enumerate(feeds):
        interpreter.set_tensor(input_details[slot]["index"], array)
    interpreter.invoke()
    batch_scores = interpreter.get_tensor(output_details[0]["index"])
    return batch_scores[0]
def fetch_url_content(url):
    """Download ``url`` and return the response body as text.

    Failures are reported in the app with ``st.error`` rather than raised.
    The request uses a browser-like User-Agent and ``timeout=10`` seconds,
    which requests applies to the connect and read phases.

    Args:
        url: Absolute URL to fetch.

    Returns:
        str | None: The decoded body (possibly empty) for an HTTP 200
        response, or ``None`` for any other status code or on error.
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36",
        "Accept-Language": "en-US,en;q=0.9",
        # No explicit Accept-Encoding: requests sends its own default, which
        # only advertises codings it can decode. Hard-coding "br" let servers
        # reply with Brotli, which is left undecoded (garbled text) unless a
        # Brotli package happens to be installed.
    }
    try:
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code != 200:
            st.error(f"Failed to fetch content. Status code: {response.status_code}")
            return None
        # Without a charset in Content-Type, requests assumes ISO-8859-1 for
        # text/* responses, which garbles UTF-8 pages. Let it detect the
        # encoding from the body instead.
        if "charset" not in response.headers.get("Content-Type", "").lower():
            response.encoding = response.apparent_encoding
        return response.text
    except Exception as e:
        # UI boundary: report any failure (network, invalid URL, decoding)
        # to the user instead of crashing the app.
        st.error(f"Error fetching content: {e}")
        return None
# Streamlit app
st.title("Topic Classification from URL")
# Strip surrounding whitespace: pasted URLs often carry stray spaces, and a
# whitespace-only entry should not trigger a (failing) request at all.
url = st.text_input("Enter a URL:", "").strip()
if url:
    st.write("Extracting content from the URL...")
    raw_content = fetch_url_content(url)
    if raw_content:
        # Extract text from the fetched HTML; a falsy result means nothing
        # usable was found.
        content = trafilatura.extract(raw_content)
        if content:
            st.write("Content extracted successfully!")
            st.write(content[:500])  # Display a snippet of the content
            # Load resources. Each loader is wrapped in st.cache_resource,
            # so reruns reuse the same vocab, labels and interpreter.
            vocab = load_vocab()
            labels_dict = load_labels()
            interpreter, input_details, output_details = load_model()
            # Preprocess content and classify
            input_word_ids, input_mask, input_type_ids = preprocess_text(content, vocab)
            predictions = classify_text(
                interpreter,
                input_details,
                output_details,
                input_word_ids,
                input_mask,
                input_type_ids,
            )
            # Display the five highest-scoring topics, best first.
            st.write("Topic Classification:")
            sorted_indices = np.argsort(predictions)[::-1][:5]
            for idx in sorted_indices:
                # Output positions are used directly as taxonomy IDs.
                # NOTE(review): confirm the model's output index i really
                # corresponds to taxonomy ID i (and not, e.g., i + 1).
                topic = labels_dict.get(idx, "Unknown Topic")
                st.write(f"ID: {idx} - Topic: {topic} - Score: {predictions[idx]:.4f}")
        else:
            st.error("Unable to extract content from the fetched HTML.")
    else:
        st.error("Failed to fetch the URL.")