import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Load the pretrained KoT5 news-summarization tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("noahkim/KoT5_news_summarization")
model = AutoModelForSeq2SeqLM.from_pretrained("noahkim/KoT5_news_summarization")

# Use a GPU if one is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Summarize a block of Korean text with beam-search generation
def summarize_text(input_text):
    # Tokenize, padding/truncating the input to at most 2048 tokens
    inputs = tokenizer(input_text, return_tensors="pt", padding="max_length", truncation=True, max_length=2048)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    summary_text_ids = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=512,           # upper bound on summary length (tokens)
        min_length=128,           # force a reasonably detailed summary
        num_beams=6,              # beam search width
        repetition_penalty=1.5,   # discourage repeated tokens
        no_repeat_ngram_size=15,  # block long repeated n-grams
    )

    # Decode the generated token ids back into text, dropping special tokens
    summary_text = tokenizer.decode(summary_text_ids[0], skip_special_tokens=True)
    return summary_text
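
# Minimal local smoke test, a sketch assuming the script is run directly rather
# than on a Space; the input string below is only a placeholder, not sample data
# from the model card.
# print(summarize_text("여기에 요약할 한국어 뉴스 기사 본문을 붙여 넣으세요."))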

# Define the Gradio interface
iface = gr.Interface(
    fn=summarize_text, 
    inputs=gr.Textbox(label="Input Text"), 
    outputs=gr.Textbox(label="Summary")
)

# Launch so the app runs directly on a Hugging Face Space
iface.launch()
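
# Note: launch() with default arguments is enough for a Space. When running
# locally, share=True can optionally be passed to obtain a temporary public URL
# (assumes a Gradio version with tunneling support):
# iface.launch(share=True)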