|
import gradio as gr |
|
import torch |
|
from transformers import pipeline, set_seed |
|
import re |
|
import random |
|
import os |
|
|
|
|
|
def initialize_model():
    """Create the GPT-2 text-generation pipeline, preferring ``gpt2-medium``.

    The medium checkpoint is tried first (float16 when CUDA is available,
    float32 otherwise). If building it raises, the base ``gpt2`` checkpoint
    is loaded instead without an explicit ``torch_dtype``. Progress and
    errors are reported with ``print``.

    Returns:
        The pipeline object on success, or ``None`` if the device probe or
        both loading attempts raised.
    """
    try:
        cuda_ok = torch.cuda.is_available()
        device_id = 0 if cuda_ok else -1
        device_label = "GPU" if device_id == 0 else "CPU"
        print(f"Using device: {device_label}")

        try:
            # Preferred model; half precision on GPU.
            text_pipe = pipeline(
                "text-generation",
                model="gpt2-medium",
                device=device_id,
                pad_token_id=50256,
                torch_dtype=torch.float16 if cuda_ok else torch.float32,
            )
        except Exception as medium_err:
            print(f"GPT2-medium failed, falling back to GPT2: {medium_err}")
            # Smaller fallback checkpoint with the library's default dtype.
            text_pipe = pipeline(
                "text-generation",
                model="gpt2",
                device=device_id,
                pad_token_id=50256,
            )
            print("Loaded GPT2 model successfully")
        else:
            print("Loaded GPT2-medium model successfully")

        return text_pipe
    except Exception as load_err:
        print(f"Error loading model: {load_err}")
        return None
|
|
|
|
|
# Load the text-generation pipeline once, at import time. It is None when
# loading failed; generate_short_caption then returns "" and
# generate_instagram_captions skips model-generated captions.
generator = initialize_model()
|
|
|
# Any match disqualifies a caption: links/domain fragments, runs of 3+ digits,
# and any character that is not a word character, whitespace or an apostrophe.
# Apostrophes are allowed because generate_short_caption keeps them when it
# cleans model output (e.g. "don't", "it's").
_INVALID_CAPTION_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"http[s]?://",
        r"www\.",
        r"\.com",
        r"\.org",
        r"\.net",
        r"\d{3,}",
        r"[^\w\s']",
    )
)

# Phrases that suggest the text stopped mid-sentence when they close a caption.
# Stored as word tuples so matching is on whole words ("sushi do" must not
# match the ending "i do").
_INCOMPLETE_ENDINGS = tuple(
    tuple(phrase.split())
    for phrase in (
        "i get", "i am", "i will", "i have", "i do", "i can",
        "you get", "you are", "you will", "you have", "you do",
        "we get", "we are", "we will", "we have", "we do",
        "they get", "they are", "they will", "they have",
        "there is a", "there are", "this is a", "that is",
        "it is a", "here is", "when you", "if you",
        "the best", "the most", "the only", "the first",
    )
)

# Lorem-ipsum placeholder vocabulary. Matched as whole words so ordinary words
# that merely contain them ("blessed", "elite", "temporary") are not rejected.
_NONSENSICAL_WORDS = frozenset({
    "lorem", "ipsum", "dolor", "amet", "consectetur",
    "adipiscing", "elit", "sed", "eiusmod", "tempor",
})

# One-letter words that are accepted.
_ALLOWED_SINGLE_LETTER_WORDS = frozenset({"i", "a"})


def is_valid_caption(caption):
    """Return True if *caption* looks like a usable, complete short caption.

    A caption is rejected when it:
      * is empty/None or shorter than 3 characters after stripping whitespace;
      * contains a URL/domain fragment, 3+ consecutive digits, or punctuation
        other than an apostrophe;
      * ends with a dangling phrase such as "i am" or "the best"
        (compared word by word, case-insensitively);
      * contains a lorem-ipsum placeholder word;
      * has fewer than two distinct words (case-insensitive);
      * contains a one-letter word other than "i" or "a".

    Args:
        caption: Candidate caption text (may be None or empty).

    Returns:
        bool: True if every check passes, otherwise False.
    """
    if not caption or len(caption.strip()) < 3:
        return False

    if any(pattern.search(caption) for pattern in _INVALID_CAPTION_PATTERNS):
        return False

    lowered = [word.lower() for word in caption.split()]

    for ending in _INCOMPLETE_ENDINGS:
        # A slice shorter than the ending can never compare equal, so short
        # captions need no special case here.
        if tuple(lowered[-len(ending):]) == ending:
            return False

    if any(word in _NONSENSICAL_WORDS for word in lowered):
        return False

    if len(set(lowered)) < 2:
        return False

    # Stray single characters other than "I"/"a" are treated as noise.
    if any(len(word) == 1 and word not in _ALLOWED_SINGLE_LETTER_WORDS
           for word in lowered):
        return False

    return True
|
|
|
def generate_short_caption(description, max_retries=5):
    """Sample one short caption (2-5 words) for *description* from the model.

    Every attempt reseeds the generator, picks one of five prompt templates at
    random and samples three continuations. Each continuation has the prompt
    removed, punctuation other than apostrophes stripped and whitespace
    collapsed. Its first 5, 4 and then 3 words are tried against
    ``is_valid_caption``, and the first prefix that passes is returned.

    Args:
        description: Free-text photo description. Only its first 40
            characters are put into the prompt.
        max_retries: Number of sampling rounds before giving up.

    Returns:
        str: The caption, or "" when the description is blank, the model is
        unavailable (``generator is None``) or no attempt produced a valid
        caption. Errors raised during an attempt are printed and the next
        attempt is made.
    """
    if not description.strip() or generator is None:
        return ""

    # The templates depend only on the description, so build them once.
    topic = description[:40]
    prompt_templates = [
        f"Complete Instagram caption about {topic}:",
        f"Short social media caption: {topic}. Caption:",
        f"Trendy Instagram post about {topic}:",
        f"Cool caption for {topic}:",
        f"Fun social media post: {topic}. Text:",
    ]

    for attempt in range(max_retries):
        try:
            # Fresh seed per attempt so retries sample different text.
            set_seed(random.randint(1, 10000))
            prompt = random.choice(prompt_templates)

            generated = generator(
                prompt,
                # Budget only the continuation. max_length counts prompt
                # *tokens* plus new tokens, and GPT-2 prompts usually have
                # more tokens than words, so a word-count-based max_length
                # left little or no room for the caption.
                max_new_tokens=12,
                num_return_sequences=3,
                temperature=0.7,
                do_sample=True,
                top_p=0.85,
                repetition_penalty=1.4,
                pad_token_id=50256,
                truncation=True,
            )

            for result in generated:
                # Remove the echoed prompt from the generated text.
                caption = result["generated_text"].replace(prompt, "").strip()
                if not caption:
                    continue

                # Keep word characters, whitespace and apostrophes, then
                # collapse whitespace runs to single spaces.
                words = re.sub(r"[^\w\s']", "", caption).split()

                # Prefer the longest prefix (5, then 4, then 3 words) that
                # passes validation.
                for max_words in (5, 4, 3):
                    prefix = words[:max_words]
                    if len(prefix) >= 2:
                        candidate = " ".join(prefix)
                        if is_valid_caption(candidate):
                            return candidate

        except Exception as e:
            # Best effort: report the failure and move on to the next attempt.
            print(f"Generation error (attempt {attempt + 1}): {e}")

    return ""
|
|
|
# Hand-written captions offered when the description mentions a keyword.
# Keywords are matched as whole words, optionally followed by "s"/"es"
# (see _KEYWORD_PATTERNS).
_KEYWORD_CAPTIONS = {
    # Beach / summer
    "beach": ["Beach vibes only"],
    "sea": ["Sea you later"],
    "ocean": ["Ocean child wild heart"],
    "sand": ["Sandy toes salty kisses"],
    "sun": ["Chasing golden hour"],
    "sunset": ["Golden hour magic"],
    "swimming": ["Making waves today"],
    "surfing": ["Riding the waves"],
    # Food and drink
    "food": ["Good food good mood"],
    "coffee": ["But first coffee"],
    "pizza": ["Pizza my heart"],
    "brunch": ["Brunch squad assembled"],
    # Fitness
    "gym": ["Sweat now shine later"],
    "workout": ["No pain no gain"],
    "yoga": ["Peace love and yoga"],
    "running": ["Run like you stole something"],
    # Travel and outdoors
    "travel": ["Wanderlust and city dust"],
    "adventure": ["Adventure is out there"],
    "hiking": ["Take only pictures"],
    "camping": ["Into the wild"],
    # Celebrations
    "party": ["Party like it's 1999"],
    "birthday": ["Another year wiser"],
    "celebration": ["Cheers to good times"],
    # Friends
    "friends": ["Squad goals right here"],
    "squad": ["Squad goals achieved"],
    "bestie": ["Partner in crime"],
    # Fashion
    "outfit": ["Outfit on point"],
    "fashion": ["Fashion is art"],
    "style": ["Style is eternal"],
    # Love
    "love": ["Love is everything"],
    "couple": ["Better together always"],
    "date": ["Date night vibes"],
    # Nature
    "nature": ["Nature is medicine"],
    "flowers": ["Bloom where planted"],
    "garden": ["Garden therapy session"],
    # Work
    "work": ["Work hard play harder"],
    "success": ["Success tastes sweet"],
    "business": ["Business mode activated"],
}

# Word-boundary patterns, so "sea" no longer fires on "research", "date" on
# "update" or "sand" on "thousand". Simple plurals ("beaches", "sunsets")
# still match.
_KEYWORD_PATTERNS = [
    (re.compile(rf"\b{re.escape(keyword)}(?:e?s)?\b"), captions)
    for keyword, captions in _KEYWORD_CAPTIONS.items()
]

# Generic captions used to top the list up to five entries.
_FALLBACK_CAPTIONS = (
    "Picture perfect moment",
    "Living my best life",
    "Good vibes only",
    "Making memories today",
    "Life is beautiful",
)

# Line markers, one per output caption, in display order.
_CAPTION_EMOJIS = ("π―", "β‘", "π", "π₯", "π«")

_MAX_CAPTIONS = 5
_MAX_CAPTION_WORDS = 5
# Upper bound on generate_short_caption calls per request.
_MAX_AI_ATTEMPTS = 15


def _keyword_captions_for(description):
    """Return the curated captions whose keyword occurs as a word in *description*."""
    text = description.lower()
    matches = []
    for pattern, captions in _KEYWORD_PATTERNS:
        if pattern.search(text):
            matches.extend(captions)
    return matches


def _ai_captions_for(description):
    """Collect up to five distinct valid model captions, with bounded attempts."""
    captions = []
    if generator is None:
        return captions
    for _ in range(_MAX_AI_ATTEMPTS):
        if len(captions) >= _MAX_CAPTIONS:
            break
        caption = generate_short_caption(description)
        if caption and caption not in captions and is_valid_caption(caption):
            captions.append(caption)
    return captions


def generate_instagram_captions(description):
    """Build up to five short captions for a photo description, one per line.

    Curated keyword captions come first, then model-generated captions (only
    when the model loaded), then generic fallbacks fill any remaining slots.
    Duplicates are dropped, each caption is cut to at most five words,
    captions left with fewer than two words are skipped, and each line is
    prefixed with an emoji marker.

    Args:
        description: Free-text description of the photo.

    Returns:
        str: Newline-separated captions, or a message asking for a description
        when *description* is blank.
    """
    if not description.strip():
        return "Please enter a description of your photo! πΈ"

    candidates = _keyword_captions_for(description) + _ai_captions_for(description)

    # dict.fromkeys drops duplicates while keeping first-seen order.
    captions = list(dict.fromkeys(candidates))

    for fallback in _FALLBACK_CAPTIONS:
        if len(captions) >= _MAX_CAPTIONS:
            break
        if fallback not in captions:
            captions.append(fallback)

    final_captions = []
    for caption in captions[:_MAX_CAPTIONS]:
        words = caption.split()[:_MAX_CAPTION_WORDS]
        if len(words) >= 2:
            final_captions.append(" ".join(words))

    return "\n".join(
        f"{emoji} {caption}"
        for emoji, caption in zip(_CAPTION_EMOJIS, final_captions)
    )
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...): Instagram-like palette,
# gradient page background, translucent cards and a gradient button.
instagram_css = """
/* Instagram color scheme and modern design */
:root {
    --ig-primary: #E4405F;
    --ig-secondary: #833AB4;
    --ig-tertiary: #F77737;
    --ig-quaternary: #FCAF45;
    --ig-blue: #405DE6;
    --ig-bg: #FAFAFA;
    --ig-text: #262626;
    --ig-border: #DBDBDB;
    --ig-light: #FFFFFF;
}

/* Main container styling */
.gradio-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    min-height: 100vh;
    font-family: 'Instagram Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif !important;
}

/* Header styling: gradient text on a white rounded card.
   Two background layers with per-layer clipping: the gradient is clipped to
   the glyphs and the white layer to the padding box. A plain
   background-color would be clipped to the text too, hiding the card. */
.main-header {
    text-align: center;
    background-image:
        linear-gradient(45deg, var(--ig-primary), var(--ig-secondary)),
        linear-gradient(rgba(255, 255, 255, 0.9), rgba(255, 255, 255, 0.9));
    -webkit-background-clip: text, padding-box;
    background-clip: text, padding-box;
    -webkit-text-fill-color: transparent;
    color: transparent;
    font-size: 3em;
    font-weight: 400;
    padding: 1rem 0;
    border-radius: 15px;
    margin: 1rem;
    box-shadow: 0 4px 20px rgba(0,0,0,0.1);
}

.subtitle {
    text-align: center;
    color: white;
    font-size: 1.2em;
    margin-bottom: 2em;
    text-shadow: 0 1px 2px rgba(0,0,0,0.3);
    font-weight: 300;
}

/* Card styling */
.card {
    background: rgba(255, 255, 255, 0.95) !important;
    border-radius: 20px !important;
    box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important;
    backdrop-filter: blur(10px) !important;
    border: 1px solid rgba(255,255,255,0.2) !important;
    padding: 2rem !important;
    margin: 1rem !important;
    transition: all 0.3s ease !important;
}

.card:hover {
    transform: translateY(-5px) !important;
    box-shadow: 0 12px 40px rgba(0,0,0,0.15) !important;
}

/* Input styling */
.gr-textbox {
    border-radius: 15px !important;
    border: 2px solid var(--ig-border) !important;
    transition: all 0.3s ease !important;
    font-size: 1.1em !important;
    padding: 1rem !important;
}

/* The class sits on the component's wrapper element, which never receives
   focus itself; :focus-within reacts when the inner textarea is focused. */
.gr-textbox:focus-within {
    border-color: var(--ig-primary) !important;
    box-shadow: 0 0 0 3px rgba(228, 64, 95, 0.1) !important;
}

/* Button styling */
.generate-btn {
    background: linear-gradient(45deg, var(--ig-primary), var(--ig-secondary)) !important;
    border: none !important;
    color: white !important;
    font-weight: bold !important;
    font-size: 1.2em !important;
    padding: 15px 30px !important;
    border-radius: 25px !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(228, 64, 95, 0.3) !important;
    cursor: pointer !important;
}

.generate-btn:hover {
    transform: translateY(-3px) !important;
    box-shadow: 0 8px 25px rgba(228, 64, 95, 0.4) !important;
    background: linear-gradient(45deg, var(--ig-secondary), var(--ig-primary)) !important;
}

.generate-btn:active {
    transform: translateY(-1px) !important;
}

/* Output styling */
.output-box {
    background: rgba(255, 255, 255, 0.9) !important;
    border-radius: 15px !important;
    border: 2px solid var(--ig-border) !important;
    font-family: 'Instagram Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif !important;
    font-size: 1.1em !important;
    line-height: 1.8 !important;
    padding: 1rem !important;
}

/* Tips section */
.tips-section {
    background: rgba(255, 255, 255, 0.9) !important;
    border-radius: 15px !important;
    padding: 1.5rem !important;
    margin-top: 2rem !important;
    text-align: center !important;
    box-shadow: 0 4px 20px rgba(0,0,0,0.1) !important;
}

.tips-section p {
    color: var(--ig-text) !important;
    font-size: 1em !important;
    margin: 0.5rem 0 !important;
}

/* Footer styling */
.footer {
    text-align: center;
    color: white;
    font-size: 0.9em;
    margin-top: 2rem;
    padding: 1rem;
    opacity: 0.8;
}

/* Mobile responsive */
@media (max-width: 768px) {
    .main-header {
        font-size: 2.2em !important;
    }

    .card {
        margin: 0.5rem !important;
        padding: 1.5rem !important;
    }

    .generate-btn {
        font-size: 1.1em !important;
        padding: 12px 25px !important;
    }
}
"""
|
|
|
|
|
# Static markup rendered above and below the form.
_HEADER_HTML = '<h1 class="main-header">InstaCap β AI Caption Generator</h1>'
_SUBTITLE_HTML = '<p class="subtitle">Create short, catchy Instagram captions in seconds! Perfect for your social media posts β¨</p>'
_TIPS_HTML = """
<div class="tips-section">
<p>π‘ <strong>Pro Tips for Better Captions:</strong></p>
<p>π― Each caption is maximum 5 words for maximum impact</p>
<p>β‘ Be specific in your description for better AI-generated captions</p>
<p>π± Copy and paste directly to Instagram, TikTok, or any social platform</p>
<p>π Click generate again for fresh new options</p>
<p>β¨ Mix AI-generated captions with trending phrases</p>
</div>
"""
_FOOTER_HTML = """
<div class="footer">
<p>Made with β€οΈ for social media creators β’ Perfect for Instagram, TikTok, and more!</p>
</div>
"""

# Soft theme tinted red/purple to echo the Instagram palette used in the CSS.
_app_theme = gr.themes.Soft(
    primary_hue="red",
    secondary_hue="purple",
    neutral_hue="gray",
)

with gr.Blocks(
    theme=_app_theme,
    css=instagram_css,
    title="InstaCap - AI Caption Generator",
) as demo:
    gr.HTML(_HEADER_HTML)
    gr.HTML(_SUBTITLE_HTML)

    with gr.Row():
        # Left card: description input and the generate button.
        with gr.Column(scale=1, elem_classes="card"):
            description_input = gr.Textbox(
                label="π Describe your photo",
                placeholder="E.g., group photo with friends at sunset, coffee date with bestie, hiking adventure with squad...",
                lines=3,
                max_lines=5,
                elem_classes="gr-textbox",
            )
            generate_btn = gr.Button(
                "β¨ Generate Captions",
                variant="primary",
                elem_classes="generate-btn",
            )

        # Right card: read-only result box with a copy button.
        with gr.Column(scale=1, elem_classes="card"):
            caption_output = gr.Textbox(
                label="π― Your Instagram Captions",
                lines=8,
                max_lines=12,
                interactive=False,
                show_copy_button=True,
                elem_classes="output-box",
            )

    # Button click and Enter-in-textbox run the same handler.
    for _trigger in (generate_btn.click, description_input.submit):
        _trigger(
            fn=generate_instagram_captions,
            inputs=description_input,
            outputs=caption_output,
        )

    gr.HTML(_TIPS_HTML)
    gr.HTML(_FOOTER_HTML)
|
|
|
|
|
if __name__ == "__main__":
    # Bind to 0.0.0.0 on port 7860; share=False requests no public share link.
    _launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**_launch_options)