leoxia711 committed
Commit 03b1165 · verified · 1 Parent(s): fa695ee

Update function.py

Files changed (1): function.py +44 -36
function.py CHANGED
@@ -1,62 +1,74 @@
 from transformers import pipeline
 import torch
-from datasets import load_dataset
-import soundfile as sf
+import soundfile as sf
 from transformers import AutoTokenizer, AutoModelForCausalLM, VitsModel
 import numpy as np
 import re
 
-# Convert image to text description using a vision-language model
+# ====================
+# Load models globally
+# ====================
+
+# Image captioning pipeline
+image_to_text_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
+
+# Story generation model (DistilGPT2)
+story_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+story_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+
+# Text-to-speech model (Facebook MMS)
+tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
+tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
+
+
+# ====================
+# Function 1: Image → Text
+# ====================
 def img2text(url):
-    image_to_text_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
     text = image_to_text_model(url)[0]["generated_text"]
-
-    # Remove art-related words to make the description more neutral
     for word in ["illustration", "drawing", "painting", "rendering"]:
         text = text.replace(word, "").strip()
-
     return text
 
-# Generate a short story from a given text prompt
+
+# ====================
+# Function 2: Text → Story
+# ====================
 def text2story(caption):
     """
-    Generates a child-friendly story (50–100 words) from a given image caption.
-    Ensures it avoids dark/adult themes and encourages a whimsical tone.
+    Generates a child-friendly story (up to 100 words) from a given image caption.
+    Uses DistilGPT2 for fast story generation.
     """
-    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-    model = AutoModelForCausalLM.from_pretrained("distilgpt2")
-
-    # Prompt to guide the model
     prompt = (
-        f"Write a short, cheerful story for a 5-year-old based entirely on: {caption}. "
-        f"Make it magical, fun, and avoid anything scary or sad.\n\nStory:"
+        f"Write a short, cheerful story for a 5-year-old. The story must mention {caption}. "
+        f"The characters and location should be entirely based on {caption}.\n\nStory:"
     )
 
-    inputs = tokenizer(prompt, return_tensors="pt")
+    inputs = story_tokenizer(prompt, return_tensors="pt")
 
-    outputs = model.generate(
+    outputs = story_model.generate(
         inputs.input_ids,
-        max_length=150,
+        max_length=120,  # faster than 200, still enough for ~90 words
         do_sample=True,
         top_p=0.95,
-        temperature=0.9,
-        pad_token_id=tokenizer.eos_token_id
+        temperature=0.8,
+        pad_token_id=story_tokenizer.eos_token_id
     )
 
-    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    output_text = story_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
     # Remove prompt prefix if present
     if "Story:" in output_text:
         output_text = output_text.split("Story:")[-1].strip()
 
-    # Limit to ~100 words, but try to cut at a sentence ending (., !, ?)
+    # Trim to 100 words max, cutting at sentence boundaries
     word_list = output_text.split()
-    cut_text = " ".join(word_list[:130])  # give buffer for sentence endings
+    cut_text = " ".join(word_list[:130])  # small buffer
 
     sentences = re.split(r'(?<=[.!?])\s+', cut_text)
-
     trimmed_story = ""
     total_words = 0
+
     for sentence in sentences:
         sentence = sentence.strip()
         word_count = len(sentence.split())
@@ -68,7 +80,6 @@ def text2story(caption):
 
     story = trimmed_story.strip()
 
-    # If no sentence-ending punctuation found, just force cut at 100 words
     if not story:
         story = " ".join(word_list[:100])
         if not story.endswith(('.', '!', '?')):
@@ -76,20 +87,17 @@ def text2story(caption):
 
     return story
 
-# Convert text story into audio using a speech synthesis model
-def text2audio(story_text):
-    model = VitsModel.from_pretrained("facebook/mms-tts-eng")
-    tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
-
-    inputs = tokenizer(story_text, return_tensors="pt")
 
-    # Important: convert input IDs to LongTensor to avoid runtime error
-    inputs["input_ids"] = inputs["input_ids"].long()
+# ====================
+# Function 3: Story → Audio
+# ====================
+def text2audio(story_text):
+    inputs = tts_tokenizer(story_text, return_tensors="pt")
+    inputs["input_ids"] = inputs["input_ids"].long()  # Ensure correct type for VitsModel
 
     with torch.no_grad():
-        output = model(**inputs).waveform
+        output = tts_model(**inputs).waveform
 
-    # Convert tensor to NumPy array and save it as a .wav file
     audio_np = output.squeeze().cpu().numpy()
     output_path = "generated_audio.wav"
    sf.write(output_path, audio_np, 22050)
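
For reference, a minimal sketch of how the three updated functions might be chained after this commit. The import assumes function.py is on the Python path, and the image path is a hypothetical placeholder, not part of the commit:

# Hypothetical usage sketch, not part of the commit
from function import img2text, text2story, text2audio

caption = img2text("example.jpg")   # hypothetical image path; BLIP caption with art words stripped
story = text2story(caption)         # child-friendly story, trimmed to at most ~100 words
text2audio(story)                   # synthesizes speech and writes generated_audio.wav

print("Caption:", caption)
print("Story:", story)

One caveat worth double-checking: VITS checkpoints report their native sampling rate in tts_model.config.sampling_rate (16 kHz for the MMS English model), so the hard-coded 22050 passed to sf.write would make the saved file play back faster and higher-pitched than intended; writing with the config value would be safer.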