from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from difflib import SequenceMatcher


class EndpointHandler:
    def __init__(self, path=""):
        # Load the model and tokenizer from the current directory or the given path
        model_path = path if path else "."
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # evaluation mode: disables dropout for inference
    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        # Prepend the grammar-correction task prefix to each sentence
        prefix = "correct grammar for this sentence: "
        sentences_with_prefix = [prefix + s for s in sentences]
        inputs = self.tokenizer(
            sentences_with_prefix,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)
        # Note: with pure beam search (no do_sample=True), temperature has no
        # effect on generation; it is accepted here for API compatibility.
        outputs = self.model.generate(
            **inputs,
            max_length=512,
            num_beams=5,
            temperature=temperature,
            num_return_sequences=num_return_sequences,
            early_stopping=True
        )
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if num_return_sequences > 1:
            # generate() flattens all candidates; regroup them per input sentence
            grouped = [
                decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
                for i in range(len(sentences))
            ]
            return grouped
        return decoded
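    # Regrouping example (hypothetical values): with sentences = [s0, s1] and
    # num_return_sequences = 2, decoded arrives flattened in input order as
    # [s0_cand0, s0_cand1, s1_cand0, s1_cand1], and paraphrase_batch returns
    # [[s0_cand0, s0_cand1], [s1_cand0, s1_cand1]].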
    def compute_changes(self, original, enhanced):
        # Word-level diff between the original and enhanced sentences
        orig_words = original.split()
        new_words = enhanced.split()
        # Map each word index to its character span in the original sentence,
        # so char_start/char_end report character offsets rather than word indices
        offsets = []
        pos = 0
        for word in orig_words:
            pos = original.index(word, pos)
            offsets.append((pos, pos + len(word)))
            pos += len(word)
        changes = []
        matcher = SequenceMatcher(None, orig_words, new_words)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag in ("replace", "insert", "delete"):
                original_phrase = " ".join(orig_words[i1:i2])
                new_phrase = " ".join(new_words[j1:j2])
                char_start = offsets[i1][0] if i1 < len(offsets) else len(original)
                char_end = offsets[i2 - 1][1] if i2 > i1 else char_start
                changes.append({
                    "original_phrase": original_phrase,
                    "new_phrase": new_phrase,
                    "char_start": char_start,
                    "char_end": char_end,
                    "token_start": i1,
                    "token_end": i2,
                    "explanation": f"{tag} change",
                    "error_type": "",
                    "tip": ""
                })
        return changes
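    # Illustrative (hypothetical) output: compute_changes("she go home", "she goes home")
    # yields a single entry, roughly:
    # {"original_phrase": "go", "new_phrase": "goes", "char_start": 4, "char_end": 6,
    #  "token_start": 1, "token_end": 2, "explanation": "replace change", ...}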
    def __call__(self, inputs):
        # Main entry point for the Hugging Face Inference Endpoint.
        # Accept a bare string, a list of strings, or the common
        # {"inputs": ..., "parameters": {...}} wrapper.
        if isinstance(inputs, str):
            sentences = [inputs]
            parameters = {}
        elif isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            sentences = inputs.get("inputs", [])
            # A single string under "inputs" is wrapped in a list
            if isinstance(sentences, str):
                sentences = [sentences]
            parameters = inputs.get("parameters", {})
        else:
            return {
                "success": False,
                "error": "Invalid input format. Expected a string, a list of strings, or a dictionary with 'inputs' and 'parameters' keys."
            }
        # Optional generation parameters
        num_return_sequences = parameters.get("num_return_sequences", 1)
        temperature = parameters.get("temperature", 1.0)
        if not sentences:
            return {
                "success": False,
                "error": "No sentences provided."
            }
        try:
            paraphrased = self.paraphrase_batch(sentences, num_return_sequences, temperature)
            results = []
            if num_return_sequences > 1:
                # Multiple candidates per sentence: paraphrased[i] is a list
                for i, orig in enumerate(sentences):
                    for cand in paraphrased[i]:
                        results.append({
                            "original_sentence": orig,
                            "enhanced_sentence": cand,
                            "changes": self.compute_changes(orig, cand)
                        })
            else:
                # Single candidate per sentence: paraphrased is a flat list
                for orig, cand in zip(sentences, paraphrased):
                    results.append({
                        "original_sentence": orig,
                        "enhanced_sentence": cand,
                        "changes": self.compute_changes(orig, cand)
                    })
            return {
                "success": True,
                "results": results,
                "sentences_count": len(sentences),
                "processed_count": len(results),
                "skipped_count": 0,
                "error_count": 0
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "sentences_count": len(sentences),
                "processed_count": 0,
                "skipped_count": 0,
                "error_count": 1
            }
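

# A minimal local smoke test: a sketch, not part of the endpoint contract.
# It assumes a seq2seq grammar-correction checkpoint (config plus weights) is
# present in MODEL_DIR; MODEL_DIR and the sample payload are hypothetical.
if __name__ == "__main__":
    MODEL_DIR = "."  # hypothetical path to the downloaded model directory
    handler = EndpointHandler(MODEL_DIR)
    payload = {
        "inputs": ["she go to school every day"],
        "parameters": {"num_return_sequences": 1, "temperature": 1.0}
    }
    print(handler(payload))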