import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import streamlit as st
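# Run with: streamlit run <this_file>.py  (placeholder filename; requires the
# torch, transformers, and streamlit packages to be installed)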

def generate_blog(title, model_name='gpt2', max_length=500):
    # Check if a GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.write(f"Using device: {device}")

    # Load the tokenizer and model
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
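    # Note: Streamlit reruns the script on every interaction, so the model is
    # reloaded on each button click; st.cache_resource could be used to cache it.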

    prompt = f"Write a blog post based on this Title: {title}"

    # Tokenize the prompt (returns input_ids and an attention_mask)
    inputs = tokenizer(prompt, return_tensors='pt').to(device)

    # Generate text; max_length counts the prompt tokens as well.
    # GPT-2 has no pad token, so the EOS token id is reused for padding.
    with torch.no_grad():
        output = model.generate(**inputs, max_length=max_length, num_return_sequences=1,
                                no_repeat_ngram_size=2, pad_token_id=tokenizer.eos_token_id)

    # Decode the generated text
    blog_post = tokenizer.decode(output[0], skip_special_tokens=True)

    return blog_post

st.title("AI Blog Writer")
st.write("Enter a blog title, and the AI will generate a blog post for you!")

title = st.text_input("Enter the blog title:")

if st.button("Generate Blog"):
    if title:
        with st.spinner("Generating blog post..."):
            blog_post = generate_blog(title)
        st.subheader("Generated Blog Post")
        st.write(blog_post)
    else:
        st.warning("Please enter a blog title.")