# assignment3 / function.py
# (Hugging Face page header: author leoxia711; commit 6ecc8b7 "Update function.py", verified)
from transformers import pipeline
import torch
from datasets import load_dataset
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM, VitsModel
import numpy as np
import re
# Convert image to text description using a vision-language model
def img2text(url):
    """Caption an image with BLIP and remove words that describe the art medium.

    Args:
        url: Image input passed straight to the Transformers ``image-to-text``
            pipeline (the name suggests a URL; the pipeline also accepts other
            image inputs such as local paths).

    Returns:
        The generated caption with the whole words "illustration", "drawing",
        "painting" and "rendering" (and their plurals, any case) removed and
        runs of whitespace collapsed to single spaces, so the description
        focuses on the subject rather than the art form.
    """
    # NOTE(review): the model is reloaded on every call; consider caching it
    # if this is invoked repeatedly.
    image_to_text_model = pipeline(
        "image-to-text", model="Salesforce/blip-image-captioning-large"
    )
    text = image_to_text_model(url)[0]["generated_text"]
    # Match whole words only: plain str.replace would also mangle longer words
    # (e.g. "drawings" -> "s").
    text = re.sub(
        r"\b(?:illustration|drawing|painting|rendering)s?\b",
        "",
        text,
        flags=re.IGNORECASE,
    )
    # Deleting a word leaves a double space behind ("a  of a cat"); collapse
    # all whitespace runs and trim the ends.
    return " ".join(text.split())
def _trim_story(text, max_words=100, window=130):
    """Cut *text* to at most *max_words* words, ending on a sentence boundary if possible.

    Whole sentences (split after ``.``, ``!`` or ``?``) are kept while they
    fit in the word budget. If not even the first sentence fits, the text is
    hard-cut at *max_words* words. A final ``.`` is appended when the result
    does not already end in sentence punctuation, so empty input yields ".".
    """
    words = text.split()
    # Scan a window somewhat past the limit so the sentence straddling the
    # limit stays long enough to be rejected, instead of being truncated into
    # a short fragment that would still fit the budget.
    window_text = " ".join(words[:window])
    kept = []
    total = 0
    for sentence in re.split(r"(?<=[.!?])\s+", window_text):
        sentence = sentence.strip()
        count = len(sentence.split())
        if total + count > max_words:
            break
        if sentence:
            kept.append(sentence)
            total += count
    story = " ".join(kept)
    if not story:
        # No sentence-ending punctuation within budget: force a cut.
        story = " ".join(words[:max_words])
    if not story.endswith((".", "!", "?")):
        story += "."
    return story


# Generate a short story from a given text prompt
def text2story(caption):
    """Generate a short story (at most 100 words) from an image caption.

    The caption is embedded in a fixed prompt that asks for a heartwarming
    children's story, and the prompt is continued by the
    ``pranavpsv/gpt2-genre-story-generator`` GPT-2 model using nucleus
    sampling, so results differ between calls. The tone and a minimum length
    are requested by the prompt only; nothing here enforces them.

    Args:
        caption: Text describing the image, e.g. the output of ``img2text``.

    Returns:
        The story, trimmed to at most 100 words (preferably at a sentence end)
        and always ending in ``.``, ``!`` or ``?``.
    """
    # NOTE(review): tokenizer and model are reloaded on every call.
    model_name = "pranavpsv/gpt2-genre-story-generator"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Prompt to guide the model; only the caption is interpolated.
    prompt = (
        "Write a heartwarming story for a child. "
        f"Must use {caption} as places and characters in the story. "
        "\n\nStory:"
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,  # input_ids plus attention_mask, so generate() need not infer it
        # Budget new tokens independently of the prompt: the old max_length=180
        # also counted prompt tokens, so a long caption shortened the story.
        max_new_tokens=150,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the continuation. Slicing off the prompt tokens is exact,
    # whereas splitting the decoded text on "Story:" would also throw away
    # story text if the model happened to write that marker itself.
    story_text = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()
    return _trim_story(story_text)
# Convert text story into audio using a speech synthesis model
def text2audio(story_text, output_path="generated_audio.wav"):
    """Synthesize *story_text* to speech with MMS-TTS and save it as a WAV file.

    Args:
        story_text: English text to speak.
        output_path: Destination WAV file. Defaults to "generated_audio.wav"
            (relative to the current working directory); an existing file at
            that path is overwritten.

    Returns:
        The path the audio was written to (``output_path``).
    """
    # NOTE(review): model and tokenizer are reloaded on every call.
    model_name = "facebook/mms-tts-eng"
    model = VitsModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer(story_text, return_tensors="pt")
    # Kept from the original: force the IDs to int64 (LongTensor); the author
    # reports a runtime error without this conversion.
    inputs["input_ids"] = inputs["input_ids"].long()
    with torch.no_grad():
        waveform = model(**inputs).waveform
    # Convert tensor to NumPy array and save it as a .wav file.
    audio_np = waveform.squeeze().cpu().numpy()
    # Write at the model's own sampling rate (16 kHz for MMS-TTS). The previous
    # hard-coded 22050 Hz made playback ~1.4x too fast and high-pitched.
    sf.write(output_path, audio_np, model.config.sampling_rate)
    return output_path