from bertopic import BERTopic


class EndpointHandler:
    def __init__(self, model_path="SCANSKY/BERTopic-Tourism-Chinese"):
        """
        Initialize the handler. Load the BERTopic model from Hugging Face.
        """
        self.topic_model = BERTopic.load(model_path)

    def preprocess(self, data):
        """
        Preprocess the incoming request data.
        - Extract the text input from the request payload.
        """
        try:
            # Directly work with the incoming data dictionary
            text_input = data.get("inputs", "")
            return text_input
        except Exception as e:
            raise ValueError(f"Error during preprocessing: {str(e)}")

    def inference(self, text_input):
        """
        Perform inference using the BERTopic model.
        - Combine all sentences into a single document and find shared topics.
        """
        try:
            # Split text into sentences (assuming one sentence per line)
            sentences = text_input.strip().split('\n')

            # Combine all sentences into a single document
            combined_document = " ".join(sentences)

            # Perform topic inference on the combined document
            topics, probabilities = self.topic_model.transform([combined_document])

            # Prepare the results
            results = []
            for topic, prob in zip(topics, probabilities):
                topic_info = self.topic_model.get_topic(topic)
                topic_words = [word for word, _ in topic_info] if topic_info else []

                # Get the custom label for the topic. custom_labels_ is ordered like
                # get_topic_info(), so the outlier topic -1 sits at index 0 and
                # topic N is found at index N + 1.
                if hasattr(self.topic_model, "custom_labels_") and self.topic_model.custom_labels_ is not None:
                    custom_label = self.topic_model.custom_labels_[topic + 1]
                else:
                    custom_label = f"Topic {topic}"  # Fallback label

                results.append({
                    "topic": int(topic),
                    "probability": float(prob),
                    "top_words": topic_words[:5],  # Top 5 words
                    "customLabel": custom_label  # Add custom label
                })
            return results
        except Exception as e:
            raise ValueError(f"Error during inference: {str(e)}")

    def postprocess(self, results):
        """
        Postprocess the inference results into a JSON-serializable list.
        """
        return results  # Directly return the list of results

    def __call__(self, data):
        """
        Handle the incoming request.
        """
        try:
            # Preprocess the data
            text_input = self.preprocess(data)

            # Perform inference
            results = self.inference(text_input)

            # Postprocess the results
            response = self.postprocess(results)

            return response
        except Exception as e:
            return [{"error": str(e)}]  # Return the error as a list with a dictionary
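

# ---------------------------------------------------------------------------
# Usage sketch (local smoke test, not part of the handler contract).
# This assumes the "SCANSKY/BERTopic-Tourism-Chinese" model can be downloaded
# from the Hugging Face Hub. A deployed Inference Endpoint would send the same
# payload shape as JSON: {"inputs": "<newline-separated sentences>"}.
if __name__ == "__main__":
    handler = EndpointHandler()

    sample_payload = {
        "inputs": "故宫的建筑非常壮观。\n长城是热门的旅游景点。"
    }

    # Each successful result entry contains the topic id, its probability,
    # the top 5 topic words, and the custom (or fallback) topic label.
    for entry in handler(sample_payload):
        print(entry)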