SCANSKY's picture
Update handler.py
0ae5234 verified
raw
history blame contribute delete
3.02 kB
import json
from bertopic import BERTopic
class EndpointHandler:
    """
    Request handler that assigns a BERTopic topic to a piece of text.

    Follows the ``EndpointHandler`` interface used by Hugging Face custom
    inference handlers: the instance is called with a request dict and
    returns a JSON-serializable list. The request text is treated as one
    sentence per line; all sentences are merged into a single document, and
    the topic predicted for that document is returned together with its top
    words and label.
    """

    def __init__(self, model_path="SCANSKY/BERTopic-Tourism-Chinese"):
        """
        Load the BERTopic model.

        Args:
            model_path: Location passed to ``BERTopic.load`` (a local path or
                a Hugging Face Hub repo id). Defaults to the
                ``SCANSKY/BERTopic-Tourism-Chinese`` repo.
        """
        self.topic_model = BERTopic.load(model_path)

    def preprocess(self, data):
        """
        Extract the text to analyse from the request payload.

        Args:
            data: The request dict; the text is read from its ``"inputs"``
                key. The value may be a string (one sentence per line), a
                list/tuple of strings (one sentence each), or missing/None.

        Returns:
            The text as a single newline-separated string ("" when absent).

        Raises:
            ValueError: If ``data`` cannot be read with ``.get`` or
                ``"inputs"`` is neither a string nor a list of strings.
        """
        try:
            text_input = data.get("inputs", "")
        except Exception as e:
            raise ValueError(f"Error during preprocessing: {str(e)}") from e

        if text_input is None:
            return ""
        if isinstance(text_input, str):
            return text_input
        # A batch of sentences is joined with newlines so it matches the
        # one-sentence-per-line format that inference() splits on.
        if isinstance(text_input, (list, tuple)) and all(
            isinstance(sentence, str) for sentence in text_input
        ):
            return "\n".join(text_input)
        raise ValueError(
            "Error during preprocessing: 'inputs' must be a string or a list "
            f"of strings, got {type(text_input).__name__}"
        )

    @staticmethod
    def _to_probability(prob):
        """
        Convert one entry of BERTopic's ``probabilities`` into a float or None.

        ``prob`` is a scalar when BERTopic returns one probability per
        document, or a per-topic sequence when the model was built with
        ``calculate_probabilities=True``; for a sequence the highest value is
        reported. ``None`` (no probabilities available) is passed through.
        """
        if prob is None:
            return None
        try:
            return float(prob)
        except TypeError:
            # float() rejects sequences/arrays holding more than one value,
            # i.e. a per-topic distribution.
            best = max(prob, default=None)
            return None if best is None else float(best)

    def inference(self, text_input):
        """
        Assign a topic to the combined text with the BERTopic model.

        Every non-blank line of ``text_input`` is one sentence; the sentences
        are joined with spaces into a single document, and BERTopic's
        ``transform`` is run on that one document.

        Args:
            text_input: The text to analyse, one sentence per line.

        Returns:
            A list with one dict per transformed document (one entry, or an
            empty list when the text is empty/whitespace-only). Each dict has:
              - "topic": the predicted topic id (int).
              - "probability": float, or None when the model provides no
                probabilities.
              - "top_words": up to five representative words of the topic.
              - "customLabel": the model's custom label for the topic, or
                "Topic <id>" when none is available.

        Raises:
            ValueError: Wrapping any error raised while building the document
                or running the model.
        """
        try:
            # splitlines() also handles "\r\n" endings; blank lines are
            # dropped so they do not add stray spaces to the document.
            sentences = [
                line.strip() for line in text_input.splitlines() if line.strip()
            ]
            if not sentences:
                # Nothing to classify; do not embed an empty document.
                return []
            combined_document = " ".join(sentences)

            topics, probabilities = self.topic_model.transform([combined_document])
            if probabilities is None:
                # BERTopic returns None for clustering models that cannot
                # produce probabilities.
                probabilities = [None] * len(topics)

            custom_labels = getattr(self.topic_model, "custom_labels_", None)
            # custom_labels_ is ordered by topic id and starts at the outlier
            # topic -1 only when the model has one; BERTopic records this in
            # its private `_outliers` attribute (1 or 0). Fall back to 1, the
            # offset this handler previously assumed.
            label_offset = int(getattr(self.topic_model, "_outliers", 1))

            results = []
            for topic, prob in zip(topics, probabilities):
                topic = int(topic)
                # get_topic returns False for unknown topics.
                topic_info = self.topic_model.get_topic(topic)
                topic_words = [word for word, _ in topic_info] if topic_info else []

                label_index = topic + label_offset
                if custom_labels is not None and 0 <= label_index < len(custom_labels):
                    custom_label = custom_labels[label_index]
                else:
                    custom_label = f"Topic {topic}"  # Fallback label

                results.append({
                    "topic": topic,
                    "probability": self._to_probability(prob),
                    "top_words": topic_words[:5],  # Top 5 words
                    "customLabel": custom_label,
                })
            return results
        except Exception as e:
            raise ValueError(f"Error during inference: {str(e)}") from e

    def postprocess(self, results):
        """
        Return the inference results unchanged.

        The list built by inference() already contains only ints, floats,
        None, strings and lists, so no conversion is needed.
        """
        return results

    def __call__(self, data):
        """
        Handle one request: preprocess, run inference, postprocess.

        Args:
            data: The request payload (see preprocess()).

        Returns:
            The list of topic results, or ``[{"error": <message>}]`` if any
            step raised.
        """
        try:
            text_input = self.preprocess(data)
            results = self.inference(text_input)
            return self.postprocess(results)
        except Exception as e:
            # Report failures in-band as a list, matching the success shape.
            return [{"error": str(e)}]