Spaces:
Sleeping
Sleeping
no message
Browse files
main.py
CHANGED
@@ -8,7 +8,7 @@ from typing import Generator
|
|
8 |
import json # Asegúrate de que esta línea esté al principio del archivo
|
9 |
import nltk
|
10 |
import os
|
11 |
-
from transformers import pipeline,
|
12 |
|
13 |
|
14 |
nltk.data.path.append(os.getenv('NLTK_DATA'))
|
@@ -18,9 +18,7 @@ app = FastAPI()
|
|
18 |
# Initialize the InferenceClient with your model
|
19 |
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
|
20 |
|
21 |
-
|
22 |
-
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
|
23 |
-
summarizer = pipeline("summarization", model="facebook/bart-large")
|
24 |
|
25 |
# summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
|
26 |
|
@@ -96,31 +94,43 @@ def split_text_by_tokens(text, max_tokens=1024):
|
|
96 |
print("Tokenization complete.")
|
97 |
return chunks, token_counts
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
for chunk in chunks:
|
105 |
-
print(f"Summarizing chunk: {chunk[:50]}...") # Print the first 50 characters of the chunk
|
106 |
-
# Perform summarization on the chunk
|
107 |
-
summary = summarizer(chunk, max_length=500, min_length=100, do_sample=False)
|
108 |
-
if summary:
|
109 |
-
summaries.append(summary[0]['summary_text'])
|
110 |
-
print(f"Summary: {summary[0]['summary_text'][:50]}") # Print the first 50 characters of the summary
|
111 |
-
combined_summary = ' '.join(summaries)
|
112 |
-
return combined_summary
|
113 |
|
|
|
|
|
114 |
|
115 |
@app.post("/summarize")
|
116 |
async def summarize_text(request: SummarizeRequest):
|
117 |
try:
|
118 |
-
|
|
|
119 |
return JSONResponse(content={"summary": summarized_text})
|
120 |
except Exception as e:
|
121 |
-
print(f"Error during
|
122 |
raise HTTPException(status_code=500, detail=str(e))
|
123 |
|
124 |
-
|
125 |
if __name__ == "__main__":
|
126 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
8 |
import json  # Make sure this line stays at the top of the file
import os

import nltk
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
|
12 |
|
13 |
|
14 |
nltk.data.path.append(os.getenv('NLTK_DATA'))
|
|
|
18 |
# Initialize the InferenceClient with your model
# NOTE(review): `client` is not referenced anywhere in this view — confirm it is
# used elsewhere in the file, otherwise this (and its import) can be removed.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
|
20 |
|
21 |
+
|
|
|
|
|
22 |
|
23 |
# summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
|
24 |
|
|
|
94 |
print("Tokenization complete.")
|
95 |
return chunks, token_counts
|
96 |
|
97 |
+
# Load the tokenizer and model from the Hugging Face Hub once at import time,
# so each request doesn't pay the (multi-second) model-load cost.
tokenizer = AutoTokenizer.from_pretrained("nsi319/legal-pegasus")
model = AutoModelForSeq2SeqLM.from_pretrained("nsi319/legal-pegasus")


def summarize_legal_text(text, *, min_length=150, max_length=250, num_beams=9):
    """Summarize *text* with the legal-pegasus seq2seq model.

    Args:
        text: The document to summarize. Inputs longer than 1024 tokens are
            truncated by the tokenizer before generation.
        min_length: Minimum length (in tokens) of the generated summary.
        max_length: Maximum length (in tokens) of the generated summary.
        num_beams: Beam-search width; larger is slower but usually better.

    Returns:
        The decoded summary string with special tokens stripped.
    """
    # Ensure the text is within the maximum input length the model accepts.
    inputs = tokenizer.encode(text, return_tensors='pt', max_length=1024, truncation=True)

    # Beam-search generation; no_repeat_ngram_size and length_penalty keep the
    # output fluent and discourage verbatim repetition.
    # NOTE(review): min_length is enforced even for very short inputs, which can
    # push the model to pad with invented content — confirm this is intended.
    summary_ids = model.generate(
        inputs,
        num_beams=num_beams,
        no_repeat_ngram_size=3,
        length_penalty=2.0,
        min_length=min_length,
        max_length=max_length,
        early_stopping=True,
    )

    # Decode the generated token ids back to a plain string.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
+
class SummarizeRequest(BaseModel):
    """Request body for POST /summarize."""
    # The raw text to summarize; inputs beyond the model's 1024-token limit are
    # truncated downstream by the tokenizer.
    text: str
|
123 |
|
124 |
@app.post("/summarize")
async def summarize_text(request: SummarizeRequest):
    """Summarize the posted text with the legal-pegasus model.

    Returns:
        JSON body ``{"summary": <str>}`` on success.

    Raises:
        HTTPException: 400 for empty/whitespace-only input; 500 if
            summarization fails for any other reason.
    """
    # Reject empty payloads up front instead of running beam search on them
    # (min_length=150 would otherwise force the model to invent content).
    if not request.text or not request.text.strip():
        raise HTTPException(status_code=400, detail="'text' must be a non-empty string")
    try:
        # Use the newly defined summarization function
        summarized_text = summarize_legal_text(request.text)
        return JSONResponse(content={"summary": summarized_text})
    except Exception as e:
        # NOTE(review): detail=str(e) can leak internal error details to
        # clients — consider logging the exception and returning a generic
        # message instead of print().
        print(f"Error during summarization: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
133 |
|
134 |
+
|
135 |
if __name__ == "__main__":
    # Local entry point: serve the FastAPI app on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|