Spaces:
Sleeping
Sleeping
File size: 5,301 Bytes
42d64b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from difflib import SequenceMatcher
class EndpointHandler:
    """Inference-endpoint handler that grammar-corrects sentences.

    Each input sentence is prefixed with
    ``"correct grammar for this sentence: "``, run through a seq2seq model,
    and compared word-by-word with the model output to report the edits.
    """

    def __init__(self, path=""):
        """Load tokenizer and model, and move the model to GPU if available.

        Args:
            path: Directory passed to ``from_pretrained``; an empty string
                means the current working directory.
        """
        model_path = path if path else "."
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # Set to evaluation mode (inference only)

    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        """Generate grammar-corrected candidates for a batch of sentences.

        Args:
            sentences: List of input strings.
            num_return_sequences: Number of candidates to return per sentence.
            temperature: Forwarded to ``generate``. NOTE(review): decoding
                here is beam search without ``do_sample=True``; transformers
                only applies temperature when sampling, so this value
                currently has no effect on the output — confirm intent.

        Returns:
            A flat list with one string per sentence when
            ``num_return_sequences == 1``; otherwise a list of per-sentence
            lists, each holding ``num_return_sequences`` candidates.
        """
        # Add the grammar correction prefix to each sentence
        prefix = "correct grammar for this sentence: "
        sentences_with_prefix = [prefix + s for s in sentences]
        inputs = self.tokenizer(
            sentences_with_prefix,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_length=512,
            # Beam search requires num_return_sequences <= num_beams; widen
            # the beam rather than letting generate() raise for n > 5.
            num_beams=max(5, num_return_sequences),
            temperature=temperature,
            num_return_sequences=num_return_sequences,
            early_stopping=True,
        )
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if num_return_sequences > 1:
            # Consecutive runs of num_return_sequences outputs are grouped
            # per input sentence (assumes generate() orders them that way).
            return [
                decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
                for i in range(len(sentences))
            ]
        return decoded

    def compute_changes(self, original, enhanced):
        """Return the word-level edits that turn ``original`` into ``enhanced``.

        Both strings are split on whitespace and aligned with
        ``difflib.SequenceMatcher``; every non-``equal`` opcode becomes one
        change record.

        Returns:
            A list of dicts with keys:
            ``original_phrase`` / ``new_phrase`` (the differing words, joined
            by single spaces); ``char_start`` / ``char_end`` (character offsets
            into ``original``, end exclusive; equal for pure insertions);
            ``token_start`` / ``token_end`` (word indices into ``original``,
            end exclusive); ``explanation`` (``"<tag> change"``); and empty
            ``error_type`` / ``tip`` placeholders.
        """
        orig_tokens = original.split()
        new_tokens = enhanced.split()

        # Character span of every whitespace-delimited word in `original`, so
        # the matcher's word indices can be translated to character offsets.
        # Searching from the previous word's end is exact: only whitespace
        # separates consecutive words, and words contain no whitespace.
        spans = []
        cursor = 0
        for tok in orig_tokens:
            start = original.find(tok, cursor)
            cursor = start + len(tok)
            spans.append((start, cursor))

        def char_offset(idx):
            # Offset where word `idx` begins; one past the last word maps to
            # the end of the last word (0 when `original` has no words).
            if idx < len(spans):
                return spans[idx][0]
            return spans[-1][1] if spans else 0

        changes = []
        matcher = SequenceMatcher(None, orig_tokens, new_tokens)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                continue
            char_start = char_offset(i1)
            # Insertions consume no original text: empty span at char_start.
            char_end = spans[i2 - 1][1] if i2 > i1 else char_start
            changes.append({
                "original_phrase": " ".join(orig_tokens[i1:i2]),
                "new_phrase": " ".join(new_tokens[j1:j2]),
                "char_start": char_start,
                "char_end": char_end,
                "token_start": i1,
                "token_end": i2,
                "explanation": f"{tag} change",
                "error_type": "",
                "tip": "",
            })
        return changes

    def __call__(self, inputs):
        """Main entry point for the Hugging Face Endpoint.

        Accepts a single string, a list of strings, or a dict shaped like
        ``{"inputs": str | list[str], "parameters": {...}}``. Recognised
        parameters are ``num_return_sequences`` (default 1, must be >= 1)
        and ``temperature`` (default 1.0).

        Returns:
            A dict with ``"success"``. On success it also holds ``"results"``
            (one entry per returned candidate) and counters; on failure it
            holds an ``"error"`` message. Input and generation errors are
            reported in this dict rather than raised.
        """
        invalid_format = (
            "Invalid input format. Expected a string, list of strings, or a "
            "dictionary with 'inputs' and 'parameters' keys."
        )

        if isinstance(inputs, str):
            sentences = [inputs]
            parameters = {}
        elif isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            # Common {"inputs": "...", "parameters": {}} format
            sentences = inputs.get("inputs", [])
            if isinstance(sentences, str):
                sentences = [sentences]
            # Treat an explicit JSON null like a missing "parameters" key.
            parameters = inputs.get("parameters") or {}
        else:
            return {"success": False, "error": invalid_format}

        if not sentences:
            return {
                "success": False,
                "error": "No sentences provided."
            }
        # Reject non-list / non-string payloads here; otherwise they fail
        # deep in tokenization (or even inside the error handler below).
        if not isinstance(sentences, list) or not all(isinstance(s, str) for s in sentences):
            return {"success": False, "error": invalid_format}
        if not isinstance(parameters, dict):
            return {"success": False, "error": "Invalid 'parameters': expected a dictionary."}

        # JSON clients may send numbers as strings; coerce before comparing.
        try:
            num_return_sequences = int(parameters.get("num_return_sequences", 1))
            temperature = float(parameters.get("temperature", 1.0))
        except (TypeError, ValueError):
            return {
                "success": False,
                "error": "Invalid 'num_return_sequences' or 'temperature' parameter."
            }
        if num_return_sequences < 1:
            return {
                "success": False,
                "error": "'num_return_sequences' must be at least 1."
            }

        try:
            paraphrased = self.paraphrase_batch(sentences, num_return_sequences, temperature)
            if num_return_sequences == 1:
                # paraphrase_batch returns a flat list in this case; wrap each
                # output so both modes share a single result-building loop.
                paraphrased = [[cand] for cand in paraphrased]
            results = [
                {
                    "original_sentence": orig,
                    "enhanced_sentence": cand,
                    "changes": self.compute_changes(orig, cand),
                }
                for orig, candidates in zip(sentences, paraphrased)
                for cand in candidates
            ]
        except Exception as e:  # endpoint boundary: report, don't crash
            return {
                "success": False,
                "error": str(e),
                "sentences_count": len(sentences),
                "processed_count": 0,
                "skipped_count": 0,
                "error_count": 1
            }
        return {
            "success": True,
            "results": results,
            "sentences_count": len(sentences),
            "processed_count": len(results),
            "skipped_count": 0,
            "error_count": 0
        }