File size: 3,016 Bytes
d85085f
 
 
 
2f59689
d85085f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ae5234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d85085f
0ae5234
 
d85085f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import json
from bertopic import BERTopic

class EndpointHandler:
    """Custom inference-endpoint handler wrapping a pretrained BERTopic model.

    Implements the Hugging Face custom-handler protocol: the instance is
    called with a request payload and returns a JSON-serializable list.
    """

    def __init__(self, model_path="SCANSKY/BERTopic-Tourism-Chinese"):
        """Load the BERTopic model from the Hugging Face Hub (or a local path)."""
        self.topic_model = BERTopic.load(model_path)

    def preprocess(self, data):
        """Extract the raw text from the incoming request.

        Accepts either the standard request dict ``{"inputs": "..."}`` or a
        bare string (some clients post the text directly). Returns the text
        unchanged, or "" when the "inputs" key is absent.

        Raises:
            ValueError: if the payload cannot be read at all.
        """
        try:
            if isinstance(data, str):
                # Bare-string payload: nothing to unwrap.
                return data
            return data.get("inputs", "")
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise ValueError(f"Error during preprocessing: {str(e)}") from e

    def inference(self, text_input):
        """Run topic inference on the whole input treated as one document.

        The input is expected to hold one sentence per line; all lines are
        merged into a single document so the model assigns shared topic(s).

        Returns:
            list[dict]: one entry per inferred topic with keys
            "topic", "probability", "top_words" (top 5), "customLabel".
            Returns [] for empty or whitespace-only input.

        Raises:
            ValueError: wrapping any model failure (original cause chained).
        """
        try:
            text = text_input.strip()
            if not text:
                # Robustness: never call the model on an empty document.
                return []

            # One sentence per line -> merge into a single document.
            combined_document = " ".join(text.split('\n'))

            # transform() expects a list of documents; we pass exactly one.
            topics, probabilities = self.topic_model.transform([combined_document])

            results = []
            for topic, prob in zip(topics, probabilities):
                # get_topic() yields [(word, score), ...]; a falsy value means
                # the topic id is unknown, so fall back to an empty word list.
                topic_info = self.topic_model.get_topic(topic)
                topic_words = [word for word, _ in topic_info] if topic_info else []

                # Assumes custom_labels_ lists the outlier topic (-1) first,
                # hence the +1 offset when indexing by topic id — TODO confirm
                # against the loaded model.
                if getattr(self.topic_model, "custom_labels_", None) is not None:
                    custom_label = self.topic_model.custom_labels_[topic + 1]
                else:
                    custom_label = f"Topic {topic}"  # Fallback label

                results.append({
                    "topic": int(topic),
                    "probability": float(prob),
                    "top_words": topic_words[:5],  # Top 5 words
                    "customLabel": custom_label  # Add custom label
                })

            return results
        except Exception as e:
            # Chain the original exception so the traceback is not lost.
            raise ValueError(f"Error during inference: {str(e)}") from e

    def postprocess(self, results):
        """Return the inference results as-is (already JSON-serializable)."""
        return results

    def __call__(self, data):
        """Handle one request end-to-end: preprocess -> inference -> postprocess.

        Never raises: any failure is returned as ``[{"error": "..."}]`` so
        the endpoint always responds with a JSON-serializable payload.
        """
        try:
            text_input = self.preprocess(data)
            results = self.inference(text_input)
            return self.postprocess(results)
        except Exception as e:
            # Endpoint boundary: report the error instead of crashing the server.
            return [{"error": str(e)}]