Marroco93 committed on
Commit
1d6eb67
1 Parent(s): ecb4879

no message

Files changed (1)
  1. main.py +36 -20
main.py CHANGED
@@ -100,33 +100,49 @@ def split_text_by_tokens(text, max_tokens=1024):
 tokenizer = AutoTokenizer.from_pretrained("nsi319/legal-pegasus")
 model = AutoModelForSeq2SeqLM.from_pretrained("nsi319/legal-pegasus")
 
-def summarize_legal_text(text):
-    # Ensure the text is within the maximum length limit for the model
-    inputs = tokenizer.encode(text, return_tensors='pt', max_length=1024, truncation=True)
-
-    # Generate summary
-    summary_ids = model.generate(
-        inputs,
-        num_beams=5,
-        no_repeat_ngram_size=3,
-        length_penalty=1.0,
-        min_length=150,
-        max_length=1000,
-        early_stopping=True
-    )
-
-    # Decode generated tokens to a string
-    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-    return summary
-
 
 class SummarizeRequest(BaseModel):
     text: str
 
+def chunk_text(text, max_length=1024):
+    """Split the text into manageable parts for the model to handle."""
+    words = text.split()
+    current_chunk = ""
+    chunks = []
+
+    for word in words:
+        if len(tokenizer.encode(current_chunk + word)) < max_length:
+            current_chunk += word + ' '
+        else:
+            chunks.append(current_chunk.strip())
+            current_chunk = word + ' '
+    chunks.append(current_chunk.strip())  # Add the last chunk
+    return chunks
+
+def summarize_legal_text(text):
+    """Generate summaries for each chunk and combine them."""
+    chunks = chunk_text(text, max_length=900)  # A bit less than 1024 to be safe
+    all_summaries = []
+
+    for chunk in chunks:
+        inputs = tokenizer.encode(chunk, return_tensors='pt', max_length=1024, truncation=True)
+        summary_ids = model.generate(
+            inputs,
+            num_beams=5,
+            no_repeat_ngram_size=3,
+            length_penalty=1.0,
+            min_length=150,
+            max_length=300,  # You can adjust this based on your needs
+            early_stopping=True
+        )
+        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+        all_summaries.append(summary)
+
+    return " ".join(all_summaries)
+
 @app.post("/summarize")
 async def summarize_text(request: SummarizeRequest):
     try:
-        # Use the newly defined summarization function
         summarized_text = summarize_legal_text(request.text)
         return JSONResponse(content={"summary": summarized_text})
     except Exception as e:
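
For reference, a quick way to exercise the updated /summarize endpoint is a plain HTTP call. The following is a minimal sketch, not part of the commit: it assumes the FastAPI app is served locally (e.g. uvicorn main:app --port 8000) and uses the requests library; the host, port, and sample text are all assumptions.

import requests

# Hypothetical smoke test for the /summarize endpoint above; the URL is assumed.
resp = requests.post(
    "http://localhost:8000/summarize",
    json={"text": "Long legal document text to summarize..."},
)
resp.raise_for_status()          # Surface any HTTP error returned by the endpoint
print(resp.json()["summary"])    # The concatenated per-chunk summaries

Because summarize_legal_text joins one summary per chunk, long inputs return proportionally longer responses; each chunk's summary is bounded by the max_length=300 generation setting in the diff.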