main.py
CHANGED
@@ -100,33 +100,49 @@ def split_text_by_tokens(text, max_tokens=1024):
 tokenizer = AutoTokenizer.from_pretrained("nsi319/legal-pegasus")
 model = AutoModelForSeq2SeqLM.from_pretrained("nsi319/legal-pegasus")
 
-def summarize_legal_text(text):
-    # Ensure the text is within the maximum length limit for the model
-    inputs = tokenizer.encode(text, return_tensors='pt', max_length=1024, truncation=True)
-
-    # Generate summary
-    summary_ids = model.generate(
-        inputs,
-        num_beams=5,
-        no_repeat_ngram_size=3,
-        length_penalty=1.0,
-        min_length=150,
-        max_length=1000,
-        early_stopping=True
-    )
-
-    # Decode generated tokens to a string
-    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-    return summary
-
 
 class SummarizeRequest(BaseModel):
     text: str
 
+def chunk_text(text, max_length=1024):
+    """Split the text into manageable parts for the model to handle."""
+    words = text.split()
+    current_chunk = ""
+    chunks = []
+
+    for word in words:
+        if len(tokenizer.encode(current_chunk + word)) < max_length:
+            current_chunk += word + ' '
+        else:
+            chunks.append(current_chunk.strip())
+            current_chunk = word + ' '
+    chunks.append(current_chunk.strip())  # Add the last chunk
+    return chunks
+
+def summarize_legal_text(text):
+    """Generate summaries for each chunk and combine them."""
+    chunks = chunk_text(text, max_length=900)  # A bit less than 1024 to be safe
+    all_summaries = []
+
+    for chunk in chunks:
+        inputs = tokenizer.encode(chunk, return_tensors='pt', max_length=1024, truncation=True)
+        summary_ids = model.generate(
+            inputs,
+            num_beams=5,
+            no_repeat_ngram_size=3,
+            length_penalty=1.0,
+            min_length=150,
+            max_length=300,  # You can adjust this based on your needs
+            early_stopping=True
+        )
+        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+        all_summaries.append(summary)
+
+    return " ".join(all_summaries)
+
 @app.post("/summarize")
 async def summarize_text(request: SummarizeRequest):
     try:
-        # Use the newly defined summarization function
         summarized_text = summarize_legal_text(request.text)
         return JSONResponse(content={"summary": summarized_text})
     except Exception as e:
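
The change above replaces single-pass summarization (which truncated everything past 1024 tokens) with chunked summarization: chunk_text splits the input into token-bounded pieces and summarize_legal_text joins the per-chunk summaries. A minimal client sketch for exercising the updated /summarize endpoint, assuming the FastAPI app from main.py is served locally (the base URL, port, and sample text are illustrative assumptions, not part of the commit):

    # Hypothetical client for the /summarize endpoint touched by this commit.
    # Assumes the app is running locally, e.g. via `uvicorn main:app --port 8000`;
    # the URL and the sample text are illustrative only.
    import requests

    sample_text = " ".join(["The parties agree to the terms set forth herein."] * 400)

    resp = requests.post(
        "http://localhost:8000/summarize",
        json={"text": sample_text},  # body shape matches the SummarizeRequest model
        timeout=600,                 # beam search over several chunks can be slow on CPU
    )
    resp.raise_for_status()
    print(resp.json()["summary"])    # per-chunk summaries joined with spaces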