Spaces:
Build error
Build error
bankholdup
committed on
Commit
•
931dc99
1
Parent(s):
67c2dd6
Update app.py
Browse files
app.py
CHANGED
@@ -9,8 +9,8 @@ from transformers import pipeline
|
|
9 |
@st.cache(allow_output_mutation=True)
|
10 |
def load_model():
|
11 |
model_ckpt = "bankholdup/rugpt3_song_writer"
|
12 |
-
tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt
|
13 |
-
model = GPT2LMHeadModel.from_pretrained(model_ckpt
|
14 |
return tokenizer, model
|
15 |
|
16 |
def set_seed(args):
|
@@ -23,24 +23,43 @@ def set_seed(args):
|
|
23 |
|
24 |
title = st.title("Loading model")
|
25 |
tokenizer, model = load_model()
|
26 |
-
text_generation = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
27 |
title.title("ruGPT3 Song Writer")
|
28 |
context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")
|
29 |
|
30 |
if st.button("Поехали", help="Может занять какое-то время"):
|
31 |
st.title(f"{context}")
|
32 |
prefix_text = f"{context}"
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): reconstructed from a mangled diff scrape; the original file's
# formatting was lost.  `st`, `GPT2Tokenizer` and `GPT2LMHeadModel` are
# imported at the top of app.py (outside this visible chunk).
@st.cache(allow_output_mutation=True)  # NOTE(review): st.cache is deprecated in newer Streamlit (st.cache_resource) — confirm installed version before migrating
def load_model():
    """Load and cache the ruGPT-3 song-writer checkpoint.

    Returns:
        tuple: ``(tokenizer, model)`` — a ``GPT2Tokenizer`` and
        ``GPT2LMHeadModel`` loaded from the
        ``"bankholdup/rugpt3_song_writer"`` Hub checkpoint.
    """
    model_ckpt = "bankholdup/rugpt3_song_writer"
    tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt)
    model = GPT2LMHeadModel.from_pretrained(model_ckpt)
    return tokenizer, model
|
15 |
|
16 |
def set_seed(args):
|
|
|
# NOTE(review): reconstructed from a mangled diff scrape of app.py's
# top-level Streamlit script.  `st` and `load_model()` come from the imports
# and definitions earlier in the file.

# Show a placeholder title while the (slow) checkpoint download/load runs,
# then retitle the page once the model is ready.
title = st.title("Loading model")
tokenizer, model = load_model()
title.title("ruGPT3 Song Writer")

# Seed prompt for the song (UI text is Russian: "Enter the start of the song").
context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")

if st.button("Поехали", help="Может занять какое-то время"):
    st.title(f"{context}")
    prefix_text = f"{context}"
    encoded_prompt = tokenizer.encode(prefix_text, add_special_tokens=False, return_tensors="pt")
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=200 + len(encoded_prompt[0]),  # generate up to 200 new tokens past the prompt
        temperature=0.95,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    )

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    # BUG FIX: the scraped code appended to `generated_sequences` without ever
    # creating it (NameError) — initialize the accumulator before the loop.
    generated_sequences = []
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        # NOTE(review): the original called "ruGPT:".format(idx + 1); .format()
        # without a placeholder ignores its argument, so it always printed the
        # bare label — keep that observable behavior.
        print("ruGPT:")
        generated_sequence = generated_sequence.tolist()

        # Decode the generated token ids back to text.
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # BUG FIX: the scraped code read `args.stop_token`, but no `args`
        # object exists at module scope here (NameError at runtime).  This app
        # configures no stop token, so the slice keeps the full text.
        stop_token = None
        # Remove all text after the stop token (no-op while stop_token is None).
        text = text[: text.find(stop_token) if stop_token else None]

        # BUG FIX: `prompt_text` was undefined (NameError); the prompt
        # variable in this script is `prefix_text`.
        # Add the prompt at the beginning of the sequence and drop the decoded
        # prompt prefix from the generated text so it is not duplicated.
        total_sequence = (
            prefix_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        # os.system('clear')
        st.write(total_sequence)
|