slliac committed (verified)
Commit d9851a0 · 1 Parent(s): 2b715a2

Update app.py

Files changed (1):
  1. app.py +32 -30
app.py CHANGED
@@ -44,41 +44,43 @@ def translate_to_chinese(text):
     return translation
 
 
-# text2story - using mosaicml/mpt-7b-storywriter model for better stories
+# text2story - using the llama-3-8b-Instruct-bnb-4bit-personal-shortstory model for better stories
 def text2story(text):
     try:
-        # Initialize the improved story generation pipeline
-        generator = pipeline("text-generation", model="2173ars/llama-3-8b-Instruct-bnb-4bit-personal-shortstory")
-
-        # Create a prompt for the story
+        # Load the story model; AutoModelForCausalLM (not the headless AutoModel) is
+        # required so that .generate() works. Requires
+        # "from transformers import AutoModelForCausalLM, AutoTokenizer" at the top of app.py.
+        model = AutoModelForCausalLM.from_pretrained("2173ars/llama-3-8b-Instruct-bnb-4bit-personal-shortstory")
         prompt = f"{text}"
-
-        story = generator(prompt,
-                          min_length=100,
-                          max_length=130,
-                          num_return_sequences=1,
-                          top_k=50,
-                          top_p=0.92,
-                          no_repeat_ngram_size=3,
-                          temperature=0.5,
-                          repetition_penalty=1.3)[0]['generated_text']
-
-        return story.replace(prompt, "").strip()
-
+        tokenizer = AutoTokenizer.from_pretrained("2173ars/llama-3-8b-Instruct-bnb-4bit-personal-shortstory")
+        inputs = tokenizer(prompt, return_tensors="pt")
+        outputs = model.generate(
+            inputs.input_ids,
+            max_new_tokens=250,
+            temperature=0.7,
+            top_p=0.9,
+            top_k=40,
+            repetition_penalty=1.2,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
+        )
+        story = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return story
     except Exception as e:
         # Fallback to simpler model if the advanced one fails
-        fallback_generator = pipeline('text-generation', model='gpt2')
-        fallback_prompt = f"{text} "
-        fallback_story = fallback_generator(fallback_prompt,
-                                            min_length=100,
-                                            max_length=130,
-                                            num_return_sequences=1,
-                                            top_k=50,
-                                            top_p=0.92,
-                                            no_repeat_ngram_size=3,
-                                            temperature=0.5,
-                                            repetition_penalty=1.3)[0]['generated_text']
-        return fallback_story.replace(fallback_prompt, "").strip()
+        # Use fallback-specific names throughout: if the try block failed before
+        # binding model/prompt/outputs, reusing those names here would raise NameError.
+        fallback_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+        fallback_prompt = f"{text}"
+        fallback_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        fallback_inputs = fallback_tokenizer(fallback_prompt, return_tensors="pt")
+        fallback_outputs = fallback_model.generate(
+            fallback_inputs.input_ids,
+            max_new_tokens=250,
+            temperature=0.7,
+            top_p=0.9,
+            top_k=40,
+            repetition_penalty=1.2,
+            do_sample=True,
+            pad_token_id=fallback_tokenizer.eos_token_id
+        )
+        fallback_story = fallback_tokenizer.decode(fallback_outputs[0], skip_special_tokens=True)
+        return fallback_story
 
 def load_css(css_file):
     with open(css_file) as f:
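
For reference, below is a minimal standalone sketch of the tokenize-then-generate() flow this commit switches to. It is a sketch under stated assumptions, not the app's exact code: it uses the public openai-community/gpt2 checkpoint from the fallback path so it runs without the LLaMA-derived weights, and the prompt string is a made-up example; the sampling parameters mirror the committed values.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "openai-community/gpt2"  # public fallback checkpoint from the diff
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)  # causal-LM class, so .generate() is available

prompt = "A cat finds a tiny door behind the bookshelf"  # hypothetical example input
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
    inputs.input_ids,
    max_new_tokens=250,                   # caps only newly generated tokens, unlike the old max_length
    do_sample=True,                       # enable sampling so temperature/top_p/top_k take effect
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    repetition_penalty=1.2,
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token by default
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

One behavioural note on the change: unlike the removed pipeline version, the new code no longer strips the prompt from the returned text, so the decoded story begins with the user's input sentence; decoding only the tokens past inputs.input_ids.shape[1] would restore the old behaviour.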