# Captured from the Hugging Face file view: last commit a755228 by
# DuckyBlender, message "fixed cpu inference?".
import gradio as gr
from youtube_transcript_api import YouTubeTranscriptApi
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import re
import os
import torch
# import dotenv
# dotenv.load_dotenv()
# Choose the compute device and load the model.
# On CUDA the model is loaded 4-bit quantized (bitsandbytes), and Flash
# Attention 2 is requested only when it installed successfully and the GPU can
# run it. On CPU the model is loaded without quantization.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
    # Install the Flash Attention library at startup.
    # FLASH_ATTENTION_SKIP_CUDA_BUILD presumably makes the installer use a
    # prebuilt wheel instead of compiling kernels; confirm against flash-attn's setup.
    print("Installing Flash Attention library...")
    import subprocess
    import sys
    # Run pip for *this* interpreter, and extend (rather than replace) the
    # current environment. An env dict holding only the one flag would drop
    # PATH, HOME, VIRTUAL_ENV, ... for the child process.
    install = subprocess.run(
        [
            sys.executable, "-m", "pip", "install", "flash_attn",
            "--no-build-isolation", "--break-system-packages",
        ],
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    )
    # Flash Attention 2 kernels require compute capability >= 8 (Ampere+).
    # Only request it when the install worked and the GPU qualifies, so that
    # loading and inference still work on older cards.
    major_capability, _ = torch.cuda.get_device_capability(device)
    if install.returncode == 0 and major_capability >= 8:
        attn_implementation = "flash_attention_2"
        print("Flash Attention library installed")
    else:
        attn_implementation = None
        print("Flash Attention unavailable; using the default attention implementation")
    # Configure 4-bit quantization for model loading.
    # The attention backend is a from_pretrained() option; BitsAndBytesConfig
    # does not use it, so it is not passed here.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
else:
    device = torch.device("cpu")
    bnb_config = None
    attn_implementation = None
    print("Using CPU")

# Optional Hugging Face token (needed only for gated/private models).
# It is None when HF_TOKEN is unset, instead of crashing with KeyError.
token = os.environ.get("HF_TOKEN")

# Load the Phi-3 model and tokenizer
print("Loading model and tokenizer...")
model_id = "microsoft/Phi-3-mini-128k-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    attn_implementation=attn_implementation,
    trust_remote_code=True,
    token=token,
)
if bnb_config is None:
    # Unquantized weights need an explicit move to the target device.
    # transformers places a bitsandbytes-quantized model on the GPU itself, and
    # some bitsandbytes versions reject .to() on 4-bit models.
    model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
# Instruction sent to the model ahead of every transcript.
system_prompt = "Summarize this YouTube video. Give a brief summary of the video content with the key points and main takeaways."
# Conversation seed holding only the system turn; summarize_text() reads this
# list when building each request.
messages = [dict(role="system", content=system_prompt)]
# Text-generation pipeline wrapping the model and tokenizer loaded above.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# Keyword arguments forwarded to every pipe(...) call.
generation_args = dict(
    max_new_tokens=32767,
    return_full_text=False,  # return only the newly generated text
    do_sample=True,
    temperature=0.2,
)
# Function to extract the video ID from a YouTube URL
def extract_video_id(url):
video_id_match = re.search(r"(?:v=|\/)([0-9A-Za-z_-]{11}).*", url)
if video_id_match:
print(f"Extracted video ID: {video_id_match.group(1)}")
return video_id_match.group(1)
return None
def get_transcript(video_id):
    """Fetch the transcript of *video_id* and return it as one space-joined string.

    Failures are not raised: any exception is reported as a string beginning
    with "Error fetching transcript: ", which callers check for.
    """
    try:
        entries = YouTubeTranscriptApi.get_transcript(video_id)
        pieces = list(map(lambda entry: entry['text'], entries))
        print(f"Transcript: {pieces}")
        return " ".join(pieces)
    except Exception as err:
        return "Error fetching transcript: " + str(err)
def summarize_text(text):
    """Summarize *text* with the chat model and return the generated summary.

    Each call sends its own conversation: the system prompt from the
    module-level ``messages`` followed by *text* as the single user turn.
    """
    # Build a new list rather than appending to the module-level ``messages``.
    # Appending made every request carry all earlier transcripts as extra user
    # turns: unbounded prompt growth, summaries mixing different videos, and
    # state shared by every caller of the app.
    conversation = [*messages, {"role": "user", "content": text}]
    output = pipe(conversation, **generation_args)
    output = output[0]['generated_text'].strip()  # type: ignore
    print(f"Summary: {output}")
    return output
def process_video(url):
    """Fetch the transcript of the YouTube video at *url* and summarize it.

    Always returns a ``(summary, transcript)`` pair, matching the two Gradio
    output boxes it is wired to. On failure, the error message is placed in
    the summary slot and the transcript slot is left empty.
    """
    video_id = extract_video_id(url)
    if not video_id:
        print("Invalid YouTube URL")
        # The click handler has two outputs; a bare string here would not map
        # onto the summary/transcript boxes.
        return "Invalid YouTube URL", ""
    transcript = get_transcript(video_id)
    # get_transcript reports failures as a string starting with "Error".
    if transcript.startswith("Error"):
        return transcript, ""
    summary = summarize_text(transcript)
    return summary, transcript
def update_embed(url):
    """Return iframe HTML that embeds the YouTube video referenced by *url*.

    If no video ID can be extracted, the same markup is returned with an empty
    ``src``, which blanks the player.
    """
    video_id = extract_video_id(url)
    src = f"https://www.youtube.com/embed/{video_id}" if video_id else ""
    return (
        "<div class='gradio-embed-container'>"
        f"<iframe class='gradio-embed' src='{src}' frameborder='0' allowfullscreen></iframe>"
        "</div>"
    )
# Gradio UI setup: two columns. The left holds the URL input, the summary and
# transcript boxes, and the button; the right holds an embedded video player.
# In the CSS below, padding-bottom: 56.25% (= 9/16) reserves a 16:9 box, and
# the iframe is absolutely positioned to fill it.
with gr.Blocks(css="""
.gradio-embed-container { position: relative; width: 100%; padding-bottom: 56.25%; height: 0; }
.gradio-embed { position: absolute; top: 0; left: 0; width: 100%; height: 100%; }
.small-font { font-size: 0.6em; }
""") as demo:
    gr.Markdown("""
# YouTube Video Summarizer using Phi-3-mini-128k-instruct
Summarize any YouTube video using the Phi-3-mini-128k-instruct model.
""")
    with gr.Row():
        # Left column: input, outputs and the trigger button.
        with gr.Column(scale=1):
            url = gr.Textbox(
                label="YouTube URL",
                placeholder="https://www.youtube.com/watch?v=dQw4w9WgXcQ",
                max_lines=1
            )
            summary = gr.Textbox(
                label="Summary",
                placeholder="Summary will appear here...",
                lines=10,
                show_label=True,
                show_copy_button=True,
                elem_classes="small-font"
            )
            transcript = gr.Textbox(
                label="Transcript",
                placeholder="Transcript will appear here...",
                lines=1,
                show_label=True,
                show_copy_button=True,
                elem_classes="small-font"
            )
            btn = gr.Button("Summarize")
            # process_video's return values fill [summary, transcript] in order.
            btn.click(fn=process_video, inputs=url, outputs=[summary, transcript])
        # Right column: the embedded player, initially with an empty src.
        with gr.Column(scale=1):
            video_embed = gr.HTML("<div class='gradio-embed-container'><iframe class='gradio-embed' src='' frameborder='0' allowfullscreen></iframe></div>")
    # Refresh the player whenever the URL text changes.
    url.change(fn=update_embed, inputs=url, outputs=video_embed)
demo.launch()