Marroco93 committed on
Commit
b95f5d7
1 Parent(s): 021d564

no message

Browse files
Files changed (1) hide show
  1. main.py +33 -7
main.py CHANGED
@@ -112,33 +112,59 @@ def segment_text(text: str, max_tokens=500): # Slightly less than 512 for safet
112
  # Use spaCy to divide the document into sentences
113
  doc = nlp(text)
114
  sentences = [sent.text.strip() for sent in doc.sents]
115
-
116
  segments = []
117
  current_segment = []
118
  current_length = 0
119
 
120
  for sentence in sentences:
121
- sentence_length = len(sentence.split()) # Simple word count
 
 
 
 
 
 
 
 
122
  if current_length + sentence_length > max_tokens:
123
- if current_segment: # Make sure there's something to add
124
- segments.append(' '.join(current_segment))
125
  current_segment = [sentence]
126
  current_length = sentence_length
127
  else:
128
  current_segment.append(sentence)
129
  current_length += sentence_length
130
 
131
- # Add the last segment if any
132
- if current_segment:
133
  segments.append(' '.join(current_segment))
134
 
135
  return segments
136
 
 
 
 
 
 
 
 
 
 
137
 
138
  classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
139
 
140
  def classify_segments(segments):
141
- return [classifier(segment) for segment in segments]
 
 
 
 
 
 
 
 
 
 
 
142
 
143
 
144
  @app.post("/summarize")
 
112
  # Use spaCy to divide the document into sentences
113
  doc = nlp(text)
114
  sentences = [sent.text.strip() for sent in doc.sents]
115
+
116
  segments = []
117
  current_segment = []
118
  current_length = 0
119
 
120
  for sentence in sentences:
121
+ sentence_words = sentence.split()
122
+ sentence_length = len(sentence_words)
123
+
124
+ # If sentence exceeds max_tokens, split it further
125
+ if sentence_length > max_tokens:
126
+ parts = split_into_parts(sentence, max_tokens)
127
+ segments.extend(parts) # Add split parts directly to segments
128
+ continue
129
+
130
  if current_length + sentence_length > max_tokens:
131
+ segments.append(' '.join(current_segment))
 
132
  current_segment = [sentence]
133
  current_length = sentence_length
134
  else:
135
  current_segment.append(sentence)
136
  current_length += sentence_length
137
 
138
+ if current_segment: # Add the last segment if any
 
139
  segments.append(' '.join(current_segment))
140
 
141
  return segments
142
 
143
def split_into_parts(text: str, max_tokens: int) -> list:
    """Split *text* into chunks of at most *max_tokens* whitespace-separated words.

    Words are obtained with ``str.split()``, so runs of whitespace collapse
    and each chunk is re-joined with single spaces.

    Args:
        text: The text to split.
        max_tokens: Maximum number of words per chunk; must be >= 1.

    Returns:
        The chunks in original order. Empty (or whitespace-only) text
        yields an empty list.

    Raises:
        ValueError: If *max_tokens* is less than 1.
    """
    # A non-positive step would either make range() raise an unrelated-looking
    # error (0) or produce no chunks at all (negative), silently losing text.
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be a positive integer, got {max_tokens}")

    words = text.split()
    return [
        " ".join(words[start:start + max_tokens])
        for start in range(0, len(words), max_tokens)
    ]
150
+
151
+
152
 
153
  classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
154
 
155
  def classify_segments(segments):
156
+ results = []
157
+ for segment in segments:
158
+ try:
159
+ if len(segment.split()) <= 512: # Ensure segment is within the limit
160
+ result = classifier(segment)
161
+ results.append(result)
162
+ else:
163
+ results.append({"error": f"Segment too long: {len(segment.split())} tokens"})
164
+ except Exception as e:
165
+ results.append({"error": str(e)})
166
+ return results
167
+
168
 
169
 
170
  @app.post("/summarize")