# TS / app.py (author: Agamrampal, revision 2607998)
import threading

import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
# ASGI application: hosts the JSON summarization route and, once mounted
# further down, the Gradio page.
app = FastAPI()

# Hugging Face checkpoint used for summarization.
_CHECKPOINT = "facebook/bart-large-cnn"

# Tokenizer and weights are loaded explicitly (once, at import time) and the
# same instances are handed to the pipeline, which reuses them.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)
summarizer = pipeline(
    task="summarization",
    model=model,
    tokenizer=tokenizer,
    framework="pt",  # PyTorch backend, chosen explicitly instead of auto-detected
)
# Define Pydantic models
class SummarizationRequest(BaseModel):
    """Request body for ``POST /api/summarize``.

    All fields except ``text`` are forwarded to the summarization pipeline as
    generation settings. The ``Field`` bounds make FastAPI reject impossible
    settings with a 422 response instead of letting them fail inside the
    model call.
    """

    # Raw input text; whitespace is normalised before summarising.
    text: str
    # Upper / lower bound on the generated summary length (tokens, as
    # interpreted by transformers' generation).
    max_length: int = Field(130, gt=0)
    min_length: int = Field(30, ge=0)
    # False -> deterministic decoding (beam search with num_beams beams).
    do_sample: bool = False
    # Beam-search width; must be at least 1.
    num_beams: int = Field(4, ge=1)
    # Stop beam search as soon as num_beams complete candidates exist.
    early_stopping: bool = True
    # Forbid repeating any n-gram of this size; 0 disables the restriction.
    no_repeat_ngram_size: int = Field(3, ge=0)
class SummarizationResponse(BaseModel):
    """Response body returned by ``POST /api/summarize``."""

    # The post-processed summary produced by the model.
    summary: str
def preprocess_text(text: str) -> str:
    """Collapse every run of whitespace in *text* into a single space.

    Leading and trailing whitespace (including newlines and tabs) is dropped,
    so a whitespace-only input yields an empty string.
    """
    words = text.split()  # no-arg split() discards empty pieces
    return " ".join(words)
# API endpoint

# Inference runs in worker threads (see summarize_text). This lock keeps it
# one-at-a-time, as it would be if run directly on the event loop, because the
# shared pipeline/tokenizer is not guaranteed to be safe for concurrent calls.
_SUMMARIZER_LOCK = threading.Lock()


def _run_summarizer(text, **generate_kwargs):
    """Call the module-level ``summarizer`` under ``_SUMMARIZER_LOCK``; return its output."""
    with _SUMMARIZER_LOCK:
        return summarizer(text, **generate_kwargs)


def _postprocess_summary(summary_text: str) -> str:
    """Tidy a raw model summary.

    Trims surrounding whitespace, upper-cases the first character and makes
    sure the text ends with sentence punctuation. Summaries already ending in
    ``.``, ``!`` or ``?`` get no extra period; an empty summary is returned
    unchanged (as an empty string).
    """
    summary_text = summary_text.strip()
    if not summary_text:
        return summary_text
    # Only the first character is changed; str.capitalize() would lower-case
    # the rest of the summary (proper nouns, acronyms).
    summary_text = summary_text[0].upper() + summary_text[1:]
    if not summary_text.endswith(('.', '!', '?')):
        summary_text += '.'
    return summary_text


@app.post("/api/summarize", response_model=SummarizationResponse)
async def summarize_text(request: SummarizationRequest):
    """Summarise ``request.text`` with the generation settings in ``request``.

    The text is whitespace-normalised with ``preprocess_text``, summarised by
    the shared pipeline and tidied by ``_postprocess_summary``.

    Raises:
        HTTPException: 422 if the text is blank after normalisation, or if
            ``min_length`` is greater than ``max_length``.
    """
    preprocessed_text = preprocess_text(request.text)
    if not preprocessed_text:
        raise HTTPException(
            status_code=422,
            detail="text must contain at least one non-whitespace character",
        )
    if request.min_length > request.max_length:
        raise HTTPException(
            status_code=422,
            detail="min_length must not be greater than max_length",
        )

    # Generation is synchronous and slow. Running it in a worker thread keeps
    # the event loop free to serve other requests (including the mounted
    # Gradio app) while a summary is being produced.
    summary = await run_in_threadpool(
        _run_summarizer,
        preprocessed_text,
        # Inputs longer than the model's maximum input length would otherwise
        # make the call fail; truncate them instead.
        truncation=True,
        max_length=request.max_length,
        min_length=request.min_length,
        do_sample=request.do_sample,
        num_beams=request.num_beams,
        early_stopping=request.early_stopping,
        no_repeat_ngram_size=request.no_repeat_ngram_size,
    )
    summary_text = _postprocess_summary(summary[0]['summary_text'])
    return SummarizationResponse(summary=summary_text)
def _api_status_message():
    """Return the fixed notice displayed by the Gradio page."""
    return "API is running at /api/summarize. This is a backend service only."


# Minimal Gradio front page: it takes no input and only points visitors at
# the JSON endpoint, since the real work happens in the FastAPI route.
demo = gr.Interface(
    fn=_api_status_message,
    inputs=None,
    outputs="text",
    title="Text Summarization API",
    description="This is a backend API service. Send POST requests to /api/summarize",
)
# Serve the Gradio page at the site root. The JSON route was registered on
# `app` before this mount, so POST /api/summarize is still matched by FastAPI
# first. mount_gradio_app mounts onto `app` in place and returns that app.
app = gr.mount_gradio_app(app, demo, path="/")
def _serve():
    """Run the combined FastAPI + Gradio app with uvicorn on 0.0.0.0:7860."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    _serve()