Marroco93 committed
Commit d11c4a1 · 1 Parent(s): 27153aa

no message

Files changed (1)
  1. main.py +17 -10
main.py CHANGED
@@ -8,7 +8,7 @@ from typing import Generator
 import json  # Make sure this line is at the top of the file
 import nltk
 import os
-from transformers import pipeline
+from transformers import pipeline, BartTokenizer
 
 
 nltk.data.path.append(os.getenv('NLTK_DATA'))
@@ -18,6 +18,10 @@ app = FastAPI()
 # Initialize the InferenceClient with your model
 client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
 
+# Assuming you've initialized the tokenizer and model for BART
+tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+summarizer = pipeline("summarization", model="facebook/bart-large")
+
 # summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
 
 summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
@@ -78,26 +82,29 @@ async def generate_text(item: Item):
     # Stream response back to the client
     return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
 
-def split_text(text, max_size=1000):
-    # Split the text into chunks of approximately `max_size` words
-    words = text.split()
-    for i in range(0, len(words), max_size):
-        yield ' '.join(words[i:i+max_size])
+
+def split_text_by_tokens(text, max_tokens=1024):
+    # Tokenize the text and split it into chunks that fit within the model's max token limit
+    tokens = tokenizer.tokenize(text)
+    for i in range(0, len(tokens), max_tokens):
+        # Convert the tokens back to a string for summarization
+        yield tokenizer.convert_tokens_to_string(tokens[i:i+max_tokens])
 
 def summarize_large_text(text):
-    chunks = list(split_text(text))
+    # Use the new split_text_by_tokens to handle large text
+    chunks = list(split_text_by_tokens(text))
     summaries = [summarizer(chunk, max_length=500, min_length=100, do_sample=False) for chunk in chunks]
-    combined_summary = ' '.join(sum[0]['summary_text'] for sum in summaries)
+    combined_summary = ' '.join(summary[0]['summary_text'] for summary in summaries)  # each pipeline call returns a list of dicts
     return combined_summary
 
 @app.post("/summarize")
 async def summarize_text(request: SummarizeRequest):
     try:
-        # Adjusting summarization for very large texts
         summarized_text = summarize_large_text(request.text)
         return JSONResponse(content={"summary": summarized_text})
     except Exception as e:
-        # Handle exceptions that could arise during summarization
+        # Log the error or provide more detail if needed
+        print(f"Error during summarization: {e}")
         raise HTTPException(status_code=500, detail=str(e))
 
 
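Two details worth making explicit about the summarization code above. First, as committed, the final summarizer = pipeline(...) line re-binds the name, so sshleifer/distilbart-cnn-12-6, not facebook/bart-large, is the model the endpoint actually runs. Second, calling a summarization pipeline on a single string returns a list of dicts, one per generated summary, which is why the join above indexes [0]['summary_text']. A minimal check of that output shape (the input text here is only illustrative):

from transformers import pipeline

summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

# A single-string call returns a list like [{'summary_text': '...'}],
# so the generated text lives at result[0]['summary_text'].
result = summarizer(
    "Long input text to condense. " * 40,
    max_length=60,
    min_length=20,
    do_sample=False,
)
print(result[0]["summary_text"])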
 
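A related caveat on max_tokens=1024: BART's 1024-position limit includes the special tokens added when each chunk is encoded, so slicing exactly 1024 content tokens can still overflow the model. A hedged variant that leaves a margin, reusing the tokenizer from the diff (the 1000 figure is a safety margin chosen here, not taken from the commit):

def split_text_by_tokens(text, max_tokens=1000):
    # Encode without special tokens, slice into windows, then decode each
    # window back to text; the margin leaves room for the BOS/EOS tokens
    # added when the model encodes each chunk.
    ids = tokenizer.encode(text, add_special_tokens=False)
    for i in range(0, len(ids), max_tokens):
        yield tokenizer.decode(ids[i:i + max_tokens], skip_special_tokens=True)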
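For completeness, a sketch of exercising the /summarize route once the app is running. The JSON field name assumes SummarizeRequest is a Pydantic model with a single text field (the handler's request.text suggests this, but the diff does not show the model); host and port are likewise assumptions:

import requests

# Assumes the FastAPI app is being served locally on port 8000.
resp = requests.post(
    "http://localhost:8000/summarize",
    json={"text": "Paste a long document here..."},
)
resp.raise_for_status()
print(resp.json()["summary"])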