Spaces:
Build error
Build error
bankholdup
committed on
Commit
•
dc60d3c
1
Parent(s):
41e6718
Update app.py
Browse files
app.py
CHANGED
@@ -13,13 +13,10 @@ def load_model():
|
|
13 |
model = GPT2LMHeadModel.from_pretrained(model_ckpt)
|
14 |
return tokenizer, model
|
15 |
|
16 |
-
def set_seed(
|
17 |
-
rd = np.random.randint(
|
18 |
-
print('seed =', rd)
|
19 |
np.random.seed(rd)
|
20 |
torch.manual_seed(rd)
|
21 |
-
if args.n_gpu > 0:
|
22 |
-
torch.cuda.manual_seed_all(rd)
|
23 |
|
24 |
title = st.title("Загрузка модели")
|
25 |
tokenizer, model = load_model()
|
@@ -28,38 +25,34 @@ context = st.text_input("Введите начало песни", "Как дел
|
|
28 |
generated_sequences = []
|
29 |
|
30 |
if st.button("Поехали", help="Может занять какое-то время"):
|
|
|
31 |
prompt_text = f"{context}"
|
32 |
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
33 |
output_sequences = model.generate(
|
34 |
input_ids=encoded_prompt,
|
35 |
-
max_length=
|
36 |
-
temperature=
|
37 |
top_k=50,
|
38 |
top_p=0.95,
|
39 |
repetition_penalty=1.0,
|
40 |
do_sample=True,
|
41 |
num_return_sequences=1,
|
42 |
)
|
43 |
-
|
44 |
-
# Remove the batch dimension when returning multiple sequences
|
45 |
if len(output_sequences.shape) > 2:
|
46 |
output_sequences.squeeze_()
|
47 |
|
48 |
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
|
49 |
print("ruGPT:".format(generated_sequence_idx + 1))
|
50 |
generated_sequence = generated_sequence.tolist()
|
51 |
-
|
52 |
-
# Decode text
|
53 |
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
54 |
|
55 |
-
# Remove all text after the stop token
|
56 |
text = text[: text.find("</s>") if "</s>" else None]
|
57 |
-
|
58 |
-
# Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
|
59 |
total_sequence = (
|
60 |
-
prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
|
61 |
)
|
62 |
|
63 |
-
generated_sequences.append(total_sequence)
|
64 |
# os.system('clear')
|
65 |
st.write(total_sequence)
|
|
|
13 |
model = GPT2LMHeadModel.from_pretrained(model_ckpt)
|
14 |
return tokenizer, model
|
15 |
|
16 |
+
def set_seed(rng=100000):
|
17 |
+
rd = np.random.randint(rng)
|
|
|
18 |
np.random.seed(rd)
|
19 |
torch.manual_seed(rd)
|
|
|
|
|
20 |
|
21 |
title = st.title("Загрузка модели")
|
22 |
tokenizer, model = load_model()
|
|
|
25 |
generated_sequences = []
|
26 |
|
27 |
if st.button("Поехали", help="Может занять какое-то время"):
|
28 |
+
set_seed()
|
29 |
prompt_text = f"{context}"
|
30 |
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
31 |
output_sequences = model.generate(
|
32 |
input_ids=encoded_prompt,
|
33 |
+
max_length=250 + len(encoded_prompt[0]),
|
34 |
+
temperature=1.95,
|
35 |
top_k=50,
|
36 |
top_p=0.95,
|
37 |
repetition_penalty=1.0,
|
38 |
do_sample=True,
|
39 |
num_return_sequences=1,
|
40 |
)
|
|
|
|
|
41 |
if len(output_sequences.shape) > 2:
|
42 |
output_sequences.squeeze_()
|
43 |
|
44 |
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
|
45 |
print("ruGPT:".format(generated_sequence_idx + 1))
|
46 |
generated_sequence = generated_sequence.tolist()
|
47 |
+
|
|
|
48 |
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
49 |
|
|
|
50 |
text = text[: text.find("</s>") if "</s>" else None]
|
51 |
+
|
|
|
52 |
total_sequence = (
|
53 |
+
prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
|
54 |
)
|
55 |
|
56 |
+
# generated_sequences.append(total_sequence)
|
57 |
# os.system('clear')
|
58 |
st.write(total_sequence)
|