# HuggingFace Space page residue (status: Sleeping) — kept as a comment so the module parses.
# Standard import block: third-party dependencies grouped together.
# (Trailing " | |" page-scrape artifacts removed — they made the file unparseable.)
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

# FastAPI application instance; Gradio is mounted onto it further below.
app = FastAPI()
# Load the BART-large-CNN summarization model once at module import time so
# every request reuses the same weights (loading is expensive).
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

# Shared summarization pipeline used by the /api/summarize endpoint.
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    framework="pt",  # Explicitly use the PyTorch backend
)
# Request/response schemas for the summarization endpoint.
class SummarizationRequest(BaseModel):
    """Payload accepted by POST /api/summarize.

    Generation defaults mirror the pipeline's tuned settings; callers may
    override any of them per request.
    """

    text: str                      # raw input text to summarize
    max_length: int = 130          # upper bound on generated summary tokens
    min_length: int = 30           # lower bound on generated summary tokens
    do_sample: bool = False        # greedy/beam decoding by default
    num_beams: int = 4             # beam search width for better results
    early_stopping: bool = True    # stop when all beams have finished
    no_repeat_ngram_size: int = 3  # forbid repeating 3-grams in the output
class SummarizationResponse(BaseModel):
    """Response body returned by POST /api/summarize."""

    summary: str  # the generated summary text
def preprocess_text(text: str) -> str:
    """Collapse all runs of whitespace (spaces, tabs, newlines) into single
    spaces and strip leading/trailing whitespace before summarization."""
    return " ".join(text.split())
# API endpoint.
# BUG FIX: the original function had no route decorator, so FastAPI never
# exposed it even though the Gradio description advertises POST /api/summarize.
@app.post("/api/summarize", response_model=SummarizationResponse)
async def summarize_text(request: SummarizationRequest) -> SummarizationResponse:
    """Summarize ``request.text`` with the shared BART pipeline.

    The input is whitespace-normalized first; generation parameters come from
    the request body (with the schema's defaults). Returns the summary wrapped
    in a :class:`SummarizationResponse`.
    """
    # Normalize whitespace so the model sees clean, single-spaced text.
    preprocessed_text = preprocess_text(request.text)

    # NOTE(review): the pipeline call is CPU/GPU-blocking inside an async
    # handler; consider fastapi.concurrency.run_in_threadpool under load.
    summary = summarizer(
        preprocessed_text,
        max_length=request.max_length,
        min_length=request.min_length,
        do_sample=request.do_sample,
        num_beams=request.num_beams,
        early_stopping=request.early_stopping,
        no_repeat_ngram_size=request.no_repeat_ngram_size,
    )
    summary_text = summary[0]['summary_text']

    # Post-process: ensure the summary ends with a terminal period.
    if summary_text and not summary_text.endswith('.'):
        summary_text += '.'
    return SummarizationResponse(summary=summary_text)
# Minimal Gradio frontend: this Space is primarily a backend API, so the UI
# only reports where the real endpoint lives.
def _status_message() -> str:
    """Return the static status string shown in the Gradio UI."""
    return "API is running at /api/summarize. This is a backend service only."


demo = gr.Interface(
    fn=_status_message,
    inputs=None,
    outputs="text",
    title="Text Summarization API",
    description="This is a backend API service. Send POST requests to /api/summarize",
)
# Attach the Gradio UI at the root path of the FastAPI application.
gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    # Serve the combined FastAPI + Gradio app on the standard Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)