import re

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

app = FastAPI()
# Allow cross-origin requests from any host (fine for a demo; tighten in production).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
class InputText(BaseModel):
    text: str
def clean_text(text: str) -> str:
    """Collapse newlines, tabs, and runs of whitespace into single spaces."""
    text = re.sub(r"[\r\n\t]+", " ", text)
    text = re.sub(r"\s{2,}", " ", text)
    return text.strip()
@app.post("/summarize")
async def summarize(input: InputText):
    cleaned_input = clean_text(input.text)
    # Note: these are plain summarization models, not instruction-tuned, so the
    # HTML-formatting request is treated as part of the text to summarize.
    prompt = (
        "Summarize the following and format it in HTML using <p>, <ul>, <li>, "
        f"and <strong> where appropriate. Preserve headings too:\n\n{cleaned_input}"
    )

    # Pick a model by input length: DistilBART handles short inputs (up to
    # 1,024 tokens), while LED supports long documents (up to 16,384 tokens).
    temp_tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
    input_length = len(temp_tokenizer(prompt)["input_ids"])
    if input_length < 512:
        model_name = "sshleifer/distilbart-cnn-12-6"
    else:
        model_name = "pszemraj/led-large-book-summary"

    # Loading the model on every request is slow; caching it at module level
    # (or via functools.lru_cache) would avoid repeated loads.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=16384,
        truncation=True,
    )
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=1024,
        min_length=50,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
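
# A minimal local-run sketch. Assumptions not in the original: that this file
# is saved as app.py, that uvicorn is installed, and that port 7860 (the usual
# Hugging Face Spaces port) is the intended one. Equivalent CLI invocation:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# Example request once the server is up:
#
#   curl -X POST http://localhost:7860/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Long article text goes here..."}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)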