import re

import soundfile as sf
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, VitsModel, pipeline
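
# All models below are downloaded from the Hugging Face Hub on first use.
# A minimal environment (assumed, versions not pinned) can be set up with:
#   pip install transformers torch soundfile pillow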

# Convert an image to a text description using a vision-language model
def img2text(url):
    """
    Captions the image at `url` (a local path or URL) with BLIP and
    returns the caption as a plain string.
    """
    # Note: the pipeline is rebuilt on every call; cache it at module
    # level if this function is called repeatedly.
    image_to_text_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
    text = image_to_text_model(url)[0]["generated_text"]

    # Remove art-related words so the description reads as a neutral scene
    for word in ["illustration", "drawing", "painting", "rendering"]:
        text = text.replace(word, "")

    # Collapse any double spaces left behind by the removals
    text = " ".join(text.split())

    return text

# Generate a short story from a given text prompt
def text2story(caption):
    """
    Generates a child-friendly story (50–100 words) from a given image caption.
    Ensures it avoids dark/adult themes and encourages a whimsical tone.
    """
    tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator")
    model = AutoModelForCausalLM.from_pretrained("pranavpsv/gpt2-genre-story-generator")

    # Prompt to guide the model toward a gentle, child-friendly story
    prompt = (
        f"Write a heartwarming story for a child. "
        f"Use {caption} for the places and characters in the story."
        f"\n\nStory:"
    )

    inputs = tokenizer(prompt, return_tensors="pt")

    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # silences the missing-mask warning
        max_length=180,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id
    )

    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Remove prompt prefix if present
    if "Story:" in output_text:
        output_text = output_text.split("Story:")[-1].strip()

    # Limit to ~100 words, but try to cut at a sentence ending (., !, ?)
    word_list = output_text.split()
    cut_text = " ".join(word_list[:130])  # give buffer for sentence endings

    sentences = re.split(r'(?<=[.!?])\s+', cut_text)

    trimmed_story = ""
    total_words = 0
    for sentence in sentences:
        sentence = sentence.strip()
        word_count = len(sentence.split())
        if total_words + word_count > 100:
            break
        if sentence:
            trimmed_story += sentence + " "
            total_words += word_count

    story = trimmed_story.strip()

    # If no complete sentence fit within the limit, hard-cut at 100 words
    if not story:
        story = " ".join(word_list[:100])
        if not story.endswith(('.', '!', '?')):
            story += "."

    return story

# Convert the text story into audio using a speech synthesis model
def text2audio(story_text):
    """
    Synthesizes `story_text` with the MMS English TTS model (a VITS variant)
    and saves the waveform to a .wav file, returning the file path.
    """
    model = VitsModel.from_pretrained("facebook/mms-tts-eng")
    tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

    inputs = tokenizer(story_text, return_tensors="pt")

    # Important: convert input IDs to LongTensor to avoid runtime error
    inputs["input_ids"] = inputs["input_ids"].long()

    with torch.no_grad():
        output = model(**inputs).waveform

    # Convert the tensor to a NumPy array and save it as a .wav file.
    # Use the model's own sampling rate (16 kHz for MMS-TTS); hard-coding
    # 22050 Hz would make the audio play back noticeably too fast.
    audio_np = output.squeeze().cpu().numpy()
    output_path = "generated_audio.wav"
    sf.write(output_path, audio_np, model.config.sampling_rate)

    return output_path
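
# A minimal end-to-end sketch of the pipeline, chaining captioning ->
# story generation -> speech synthesis. The image filename below is a
# placeholder assumption, not part of the original script; any local
# image path or image URL accepted by the BLIP pipeline will work.
if __name__ == "__main__":
    caption = img2text("example.jpg")
    print(f"Caption: {caption}")

    story = text2story(caption)
    print(f"Story: {story}")

    audio_file = text2audio(story)
    print(f"Audio saved to: {audio_file}")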