# t5-small_cnn / main.py
# Source: Hugging Face file viewer — author "Curative", commit 3e6721b
# ("Update main.py", verified), 3.12 kB.
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import gradio as gr
import threading
import uvicorn
# Load the fine-tuned T5 checkpoint once, at import time. The resulting
# module-level `model`, `tokenizer` and `device` are shared by the REST
# endpoint and the Gradio callback.
model_path = "./t5-summarizer"  # Path inside Docker container
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to() moves the weights and returns the same module.
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)
tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)

# FastAPI application that hosts the /summarize/ endpoint.
app = FastAPI()
# JSON request body for POST /summarize/, e.g. {"text": "..."}.
# Documented with comments rather than a class docstring because pydantic
# copies a class docstring into the model's JSON schema (and thus the
# OpenAPI spec) as its description.
class TextInput(BaseModel):
    # Raw text to summarize; newlines are allowed (the endpoint flattens
    # them into spaces before tokenizing).
    text: str
@app.post("/summarize/")
def summarize_text(input: TextInput):
    """Summarize the posted text with the fine-tuned T5 model.

    Args:
        input: Request body; ``input.text`` is the document to summarize.
            The parameter name shadows the ``input`` builtin but is kept so
            existing keyword callers keep working.

    Returns:
        ``{"summary": <generated text>}``. Blank or whitespace-only text
        yields ``{"summary": ""}`` without running the model.
    """
    # Collapse every whitespace run (newlines, "\r\n", tabs, repeated spaces)
    # into one space so the prompt is a single clean line.
    text = " ".join(input.text.split())
    if not text:
        # Nothing to summarize. Skip beam search: with min_length=30 it would
        # still force a ~30-token "summary" unrelated to any input.
        return {"summary": ""}

    prompt = "summarize: " + text  # task prefix prepended for T5
    # Inputs longer than 512 tokens are truncated, not rejected.
    input_ids = tokenizer.encode(
        prompt, return_tensors="pt", max_length=512, truncation=True
    ).to(device)
    summary_ids = model.generate(
        input_ids,
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}
# Summarization function for Gradio
def summarize_ui(text):
    """Gradio callback: return the summary of *text* as a plain string.

    Delegates to the REST handler ``summarize_text`` so the web UI and the
    API share a single tokenize/generate/decode pipeline and cannot drift
    apart. FastAPI's ``@app.post`` registers the handler and returns the
    function unchanged, so calling it directly is fine.
    """
    return summarize_text(TextInput(text=text))["summary"]
# Gradio interface with example texts.
# Flagging is disabled because it writes submissions to a flagging directory
# under the working directory, which (per the original note) caused a
# permission error in the container. `flagging=False` is not a documented
# gr.Interface parameter, so it could not disable anything. The real option
# is `allow_flagging` up to Gradio 4 and `flagging_mode` from Gradio 5 on.
_GRADIO_MAJOR_VERSION = int(gr.__version__.split(".", 1)[0])
_DISABLE_FLAGGING = (
    {"flagging_mode": "never"}
    if _GRADIO_MAJOR_VERSION >= 5
    else {"allow_flagging": "never"}
)

gradio_app = gr.Interface(
    fn=summarize_ui,
    inputs=gr.Textbox(lines=10, placeholder="Paste your text here..."),
    outputs=gr.Textbox(label="Summary"),
    title="Text Summarizer",
    description="Paste your long text and get a concise summary using a fine-tuned T5 model.",
    examples=[
        ["Scientists have recently discovered a new species of frog in the Amazon rainforest. This frog is notable for its bright blue legs and unique mating call, which sounds like a series of short whistles. Researchers believe that the discovery of this species could shed new light on the ecological diversity of the region."],
        ["The global economy is expected to grow at a slower pace this year, according to new forecasts released today. Economists point to ongoing geopolitical tensions, supply chain disruptions, and inflationary pressures as key factors contributing to the reduced growth outlook."],
        ["In a thrilling final match, the underdog team scored a last-minute goal to secure their first championship title. Fans erupted into celebration as the team lifted the trophy, marking a historic moment in the club's history."],
    ],
    **_DISABLE_FLAGGING,
)
# Function to run Gradio in a thread
def run_gradio():
    """Serve the Gradio UI on every interface at port 7860, no share link.

    With Gradio's default ``prevent_thread_lock=False`` this call blocks the
    calling thread while the server runs, hence it is meant as a thread target.
    """
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    gradio_app.launch(**launch_options)
# Run Gradio in a separate thread, started at import time so the UI is up
# whether the app is launched via `python main.py` or `uvicorn main:app`.
# daemon=True: gradio_app.launch() blocks its thread while serving, so as a
# non-daemon thread it kept the interpreter alive after uvicorn (the main
# thread) shut down, and the process hung. NOTE(review): the UI now stops
# when the main thread finishes; anything that imports this module without
# blocking its main thread will no longer keep the UI running.
threading.Thread(target=run_gradio, name="gradio-ui", daemon=True).start()
# Local-development entry point (`python main.py`): serve the REST API with
# uvicorn on port 8000, next to the Gradio UI thread on port 7860.
if __name__ == "__main__":
    api_bind = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **api_bind)