Marroco93 committed on
Commit
c0b9a69
1 Parent(s): 98648e1

no message

Browse files
Files changed (1) hide show
  1. main.py +30 -20
main.py CHANGED
@@ -8,7 +8,7 @@ from typing import Generator
8
  import json # Asegúrate de que esta línea esté al principio del archivo
9
  import nltk
10
  import os
11
- from transformers import pipeline, BartTokenizer
12
 
13
 
14
  nltk.data.path.append(os.getenv('NLTK_DATA'))
@@ -18,9 +18,7 @@ app = FastAPI()
18
  # Initialize the InferenceClient with your model
19
  client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
20
 
21
- # Assuming you've initialized the tokenizer and model for BART
22
- tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
23
- summarizer = pipeline("summarization", model="facebook/bart-large")
24
 
25
  # summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
26
 
@@ -96,31 +94,43 @@ def split_text_by_tokens(text, max_tokens=1024):
96
  print("Tokenization complete.")
97
  return chunks, token_counts
98
 
99
- def summarize_large_text(text):
100
- chunks, token_counts = split_text_by_tokens(text, max_tokens=1024 - 10) # Slight buffer to avoid edge cases
101
- summaries = []
102
- print("Starting summarization of chunks...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- for chunk in chunks:
105
- print(f"Summarizing chunk: {chunk[:50]}...") # Print the first 50 characters of the chunk
106
- # Perform summarization on the chunk
107
- summary = summarizer(chunk, max_length=500, min_length=100, do_sample=False)
108
- if summary:
109
- summaries.append(summary[0]['summary_text'])
110
- print(f"Summary: {summary[0]['summary_text'][:50]}") # Print the first 50 characters of the summary
111
- combined_summary = ' '.join(summaries)
112
- return combined_summary
113
 
 
 
114
 
115
  @app.post("/summarize")
116
  async def summarize_text(request: SummarizeRequest):
117
  try:
118
- summarized_text = summarize_large_text(request.text)
 
119
  return JSONResponse(content={"summary": summarized_text})
120
  except Exception as e:
121
- print(f"Error during tokenization: {e}")
122
  raise HTTPException(status_code=500, detail=str(e))
123
 
124
-
125
  if __name__ == "__main__":
126
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
8
  import json # Asegúrate de que esta línea esté al principio del archivo
9
  import nltk
10
  import os
11
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
12
 
13
 
14
  nltk.data.path.append(os.getenv('NLTK_DATA'))
 
18
  # Initialize the InferenceClient with your model
19
  client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
20
 
21
+
 
 
22
 
23
  # summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
24
 
 
94
  print("Tokenization complete.")
95
  return chunks, token_counts
96
 
97
# Load the tokenizer and model from the Hugging Face Hub once at import time,
# so each request reuses the same weights instead of reloading them.
tokenizer = AutoTokenizer.from_pretrained("nsi319/legal-pegasus")
model = AutoModelForSeq2SeqLM.from_pretrained("nsi319/legal-pegasus")


def summarize_legal_text(text):
    """Summarize *text* with the nsi319/legal-pegasus seq2seq model.

    The input is truncated to the model's 1024-token limit, then beam
    search (9 beams) generates a summary constrained to 150-250 tokens.

    Args:
        text: Raw text to summarize.

    Returns:
        The decoded summary string, with special tokens stripped.
    """
    # Truncate to the model's maximum input length to avoid overflow errors.
    inputs = tokenizer.encode(
        text, return_tensors='pt', max_length=1024, truncation=True
    )

    # Beam-search generation; no_repeat_ngram_size suppresses repeated
    # 3-grams and length_penalty > 1 favors longer summaries.
    summary_ids = model.generate(
        inputs,
        num_beams=9,
        no_repeat_ngram_size=3,
        length_penalty=2.0,
        min_length=150,
        max_length=250,
        early_stopping=True,
    )

    # Decode the generated token ids back into a plain string.
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
119
 
 
 
 
 
 
 
 
 
 
120
 
121
class SummarizeRequest(BaseModel):
    """Request body for /summarize: the raw text to condense.

    NOTE(review): the pre-change endpoint already annotated its parameter
    with ``SummarizeRequest``, so a prior definition may exist elsewhere
    in main.py — confirm this is not a duplicate.
    """
    text: str
123
 
124
@app.post("/summarize")
async def summarize_text(request: SummarizeRequest):
    """Summarize the posted text and return it as a JSON payload."""
    try:
        # Delegate to the pegasus-based summarization function.
        summary = summarize_legal_text(request.text)
        return JSONResponse(content={"summary": summary})
    except Exception as e:
        # Top-level boundary: report the failure and surface it as a 500.
        print(f"Error during summarization: {e}")
        raise HTTPException(status_code=500, detail=str(e))
133
 
134
+
135
  if __name__ == "__main__":
136
  uvicorn.run(app, host="0.0.0.0", port=8000)