from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from difflib import SequenceMatcher

class EndpointHandler:
    def __init__(self, path=""):
        # Load model and tokenizer from the current directory or specified path
        model_path = path if path else "."
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # Set to evaluation mode

    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        # Add the grammar correction prefix to each sentence
        prefix = "correct grammar for this sentence: "
        sentences_with_prefix = [prefix + s for s in sentences]
        
        inputs = self.tokenizer(
            sentences_with_prefix,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)
        
        # Temperature only influences generation when sampling is enabled;
        # plain beam search silently ignores it, so enable sampling when a
        # non-default temperature is requested. num_return_sequences must
        # not exceed num_beams.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=512,
                num_beams=5,
                do_sample=temperature != 1.0,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                early_stopping=True
            )
        
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if num_return_sequences > 1:
            # generate() flattens candidates in input order: the k
            # candidates for sentence i occupy decoded[i*k:(i+1)*k].
            grouped = [
                decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
                for i in range(len(sentences))
            ]
            return grouped
        return decoded

    def compute_changes(self, original, enhanced):
        # Diff the two sentences word by word and describe each edit.
        orig_words = original.split()
        new_words = enhanced.split()
        changes = []
        matcher = SequenceMatcher(None, orig_words, new_words)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag in ("replace", "insert", "delete"):
                original_phrase = " ".join(orig_words[i1:i2])
                new_phrase = " ".join(new_words[j1:j2])
                # i1/i2 are word indices, not character offsets; derive the
                # character positions, assuming single spaces between words.
                char_start = len(" ".join(orig_words[:i1])) + (1 if i1 > 0 else 0)
                char_end = char_start + len(original_phrase)
                changes.append({
                    "original_phrase": original_phrase,
                    "new_phrase": new_phrase,
                    "char_start": char_start,
                    "char_end": char_end,
                    "token_start": i1,
                    "token_end": i2,
                    "explanation": f"{tag} change",
                    "error_type": "",
                    "tip": ""
                })
        return changes
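    # Example of the structure compute_changes returns (a worked example for
    # illustration; the values follow from the word-level diff above):
    #   compute_changes("he go to school", "he goes to school")
    #   -> [{"original_phrase": "go", "new_phrase": "goes",
    #        "char_start": 3, "char_end": 5, "token_start": 1, "token_end": 2,
    #        "explanation": "replace change", "error_type": "", "tip": ""}]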

    def __call__(self, inputs):
        # Main entry point for the Hugging Face Endpoint.

        # Accept a bare string, a list of strings, or the common
        # {"inputs": ..., "parameters": {...}} payload format.
        if isinstance(inputs, str):
            sentences = [inputs]
            parameters = {}
        elif isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            sentences = inputs.get("inputs", [])
            # If inputs is a single string, wrap it in a list
            if isinstance(sentences, str):
                sentences = [sentences]
            parameters = inputs.get("parameters", {})
        else:
            return {
                "success": False,
                "error": "Invalid input format. Expected a string, list of strings, or a dictionary with 'inputs' and 'parameters' keys."
            }

        # Handle optional parameters
        num_return_sequences = parameters.get("num_return_sequences", 1)
        temperature = parameters.get("temperature", 1.0)

        if not sentences:
            return {
                "success": False,
                "error": "No sentences provided."
            }

        try:
            paraphrased = self.paraphrase_batch(sentences, num_return_sequences, temperature)
            results = []
            
            if num_return_sequences > 1:
                # Logic for multiple return sequences
                for i, orig in enumerate(sentences):
                    for cand in paraphrased[i]:
                        results.append({
                            "original_sentence": orig,
                            "enhanced_sentence": cand,
                            "changes": self.compute_changes(orig, cand)
                        })
            else:
                # Logic for single return sequence
                for orig, cand in zip(sentences, paraphrased):
                    results.append({
                        "original_sentence": orig,
                        "enhanced_sentence": cand,
                        "changes": self.compute_changes(orig, cand)
                    })
            
            return {
                "success": True,
                "results": results,
                "sentences_count": len(sentences),
                "processed_count": len(results),
                "skipped_count": 0,
                "error_count": 0
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "sentences_count": len(sentences),
                "processed_count": 0,
                "skipped_count": 0,
                "error_count": 1
            }
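

# Minimal local smoke test (a sketch: it assumes the fine-tuned model files
# live in the working directory, which is what EndpointHandler defaults to).
# The payload mirrors the {"inputs": ..., "parameters": {...}} format the
# endpoint receives; the sample sentence is illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": ["she go to school every days"],
        "parameters": {"num_return_sequences": 1, "temperature": 1.0},
    }
    print(handler(payload))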