# lecture-summaries / gradio_nlp_group_project.py
# Author: cranonieu2021 - last commit 3331caa ("Update gradio_nlp_group_project.py", verified)
# -*- coding: utf-8 -*-
"""Gradio NLP Group Project.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1hDGMwj7G7avlxrqmXe6SIN9LjLRRsuqE
"""
import functools

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer, AutoModelForSequenceClassification
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api import NoTranscriptFound, TranscriptsDisabled
class TextProcessor:
    """Run the lecture-transcript NLP pipeline: summarize, translate, classify.

    The ``text`` given to the constructor is stored on ``self.text`` and is
    used by any method whose ``text`` argument is omitted or ``None``.
    Passing text explicitly still works exactly as before.

    Tokenizers and models are loaded with ``from_pretrained`` on first use.
    They are then memoized for the life of the process, so repeated requests
    reuse them instead of reloading the weights on every call.
    """

    SUMMARIZER_MODEL = "cranonieu2021/pegasus-on-lectures"
    # NOTE(review): the repo name suggests an English->Spanish MarianMT model;
    # confirm against the model card.
    TRANSLATOR_MODEL = "sfarjebespalaia/enestranslatorforsummaries"
    CLASSIFIER_MODEL = "gserafico/roberta-base-finetuned-classifier-roberta1"

    # Classifier output index -> human-readable subject area.
    LABELS = {
        0: 'Social Sciences',
        1: 'Arts',
        2: 'Natural Sciences',
        3: 'Business and Law',
        4: 'Engineering and Technology',
    }

    def __init__(self, text):
        self.text = text

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _load(model_name, tokenizer_cls, model_cls):
        """Load and memoize the (tokenizer, model) pair for ``model_name``.

        The cache key space is just the three model constants above, so an
        unbounded cache is fine here.
        """
        tokenizer = tokenizer_cls.from_pretrained(model_name)
        model = model_cls.from_pretrained(model_name)
        return tokenizer, model

    def _resolve(self, text):
        """Return ``text``, or the constructor's text when ``text`` is None."""
        return self.text if text is None else text

    def summarize_text(self, text=None):
        """Return an abstractive summary of ``text`` (default: ``self.text``).

        Input is truncated to 1024 tokens. The summary is generated with
        4-beam search, bounded by min_length=40 / max_length=150.
        """
        text = self._resolve(text)
        tokenizer, model = self._load(
            self.SUMMARIZER_MODEL, AutoTokenizer, AutoModelForSeq2SeqLM
        )
        inputs = tokenizer(text, max_length=1024, return_tensors="pt", truncation=True)
        summary_ids = model.generate(
            inputs.input_ids,
            # Pass the mask explicitly; generate() warns when it is missing.
            attention_mask=inputs.get("attention_mask"),
            max_length=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def translate_text(self, text=None):
        """Translate ``text`` (default: ``self.text``) with the MarianMT model."""
        text = self._resolve(text)
        tokenizer, model = self._load(
            self.TRANSLATOR_MODEL, MarianTokenizer, MarianMTModel
        )
        # The plain tokenizer call replaces the deprecated
        # prepare_seq2seq_batch, whose defaults were padding="longest" and
        # truncation=True.
        batch = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
        translated = model.generate(**batch)
        return tokenizer.decode(translated[0], skip_special_tokens=True)

    def classify_text(self, text=None):
        """Return the predicted subject-area label for ``text`` (default: ``self.text``)."""
        text = self._resolve(text)
        tokenizer, model = self._load(
            self.CLASSIFIER_MODEL, AutoTokenizer, AutoModelForSequenceClassification
        )
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_class_idx = torch.argmax(logits, dim=1).item()
        # If the model predicts an index outside the hard-coded mapping, fall
        # back to the model's own label name instead of raising KeyError.
        fallback = model.config.id2label.get(predicted_class_idx, str(predicted_class_idx))
        return self.LABELS.get(predicted_class_idx, fallback)
def get_transcript(video_id):
    """Fetch the English transcript of a YouTube video as one string.

    Only transcripts whose language code is exactly ``'en'`` are accepted.

    Args:
        video_id: The YouTube video ID.

    Returns:
        A ``(text, language)`` tuple. On success, ``text`` is the transcript
        segments joined with single spaces and ``language`` is ``'en'``. On
        failure, ``text`` is a human-readable error message and ``language``
        is ``None``. A tuple is always returned, so callers can unpack it
        unconditionally.
    """
    try:
        transcripts = YouTubeTranscriptApi.list_transcripts(video_id)
    except NoTranscriptFound:
        return "No transcript found for this video.", None
    except TranscriptsDisabled:
        return "Transcripts are disabled for this video.", None
    except Exception as e:
        return f"An error occurred: {str(e)}", None

    if not any(transcript.language_code == 'en' for transcript in transcripts):
        return "No transcript in a supported language (English) is available for this video.", None

    # Fetching can still fail (e.g. network errors), so report it the same
    # way as listing failures instead of raising to the caller.
    try:
        segments = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
    except Exception as e:
        return f"An error occurred: {str(e)}", None
    return ' '.join(segment['text'] for segment in segments), 'en'
def process_text(video_id):
    """Fetch a YouTube transcript and run it through the NLP pipeline.

    Args:
        video_id: YouTube video ID entered in the Gradio textbox.
            Surrounding whitespace is ignored.

    Returns:
        A dict rendered by the ``gr.JSON`` output. On success, its keys are
        'Language Detected', 'Summarized Text', 'Translated Text' and
        'Classification Result'. On failure, it has 'Error' plus
        'Language Detected'.
    """
    # Message prefixes get_transcript returns instead of transcript text on failure.
    error_prefixes = (
        "An error occurred:",
        "No transcript",
        "Transcripts are disabled",
        "Transcript in unsupported language",
    )

    video_id = (video_id or "").strip()
    if not video_id:
        return {"Error": "Please enter a YouTube video ID.", "Language Detected": "None"}

    result = get_transcript(video_id)
    # Tolerate a bare message string in place of a (text, language) tuple.
    # Unpacking one directly would raise ValueError outside any handler.
    if isinstance(result, str):
        transcript, language = result, None
    else:
        transcript, language = result

    if language is None or transcript.startswith(error_prefixes):
        return {"Error": transcript, "Language Detected": "None"}
    if not transcript.strip():
        return {"Error": "The transcript for this video is empty.", "Language Detected": language}

    processor = TextProcessor(transcript)
    try:
        summarized_text = processor.summarize_text(transcript)
        translated_text = processor.translate_text(summarized_text)
        classification_result = processor.classify_text(summarized_text)
    except Exception as e:
        # UI boundary: show model/runtime failures in the JSON panel rather
        # than surfacing them as an unhandled Gradio error.
        return {
            'Error': f"An error occurred during processing: {str(e)}",
            'Language Detected': language,
        }
    return {
        'Language Detected': language,
        'Summarized Text': summarized_text,
        'Translated Text': translated_text,
        'Classification Result': classification_result,
    }
def _build_interface():
    """Assemble the Gradio UI: one video-ID textbox in, one JSON panel out."""
    video_id_box = gr.Textbox(label="YouTube Video ID")
    results_panel = gr.JSON(label="Results")
    return gr.Interface(
        fn=process_text,
        inputs=[video_id_box],
        outputs=[results_panel],
        title="Text Processing App with YouTube Transcript",
        description="This app allows you to fetch, summarize, translate, and classify YouTube video transcripts. Errors are handled and displayed.",
    )


iface = _build_interface()
def main():
    """Start the web server for the module-level Gradio interface ``iface``."""
    app = iface
    app.launch()
# Launch the app only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()