vvv-knyazeva committed on
Commit
c2014f8
1 Parent(s): b607f03

Delete gpt_v1.py

Files changed (1)
  1. gpt_v1.py +0 -47
gpt_v1.py DELETED
@@ -1,47 +0,0 @@
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
- import torch
- import streamlit as st
-
- model = GPT2LMHeadModel.from_pretrained(
-     'sberbank-ai/rugpt3small_based_on_gpt2',
-     output_attentions=False,
-     output_hidden_states=False,
- )
-
- tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
-
-
- # Load the saved weights into our model
- model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
-
- prompt = st.text_input('Введите текст prompt:')
- length = st.slider('Длина генерируемой последовательности:', 1, 256, 16)
- num_samples = st.slider('Число генераций:', 1, 4, 1)
- temperature = st.slider('Температура:', 1.0, 5.0, 1.0)
-
-
- def generate_text(model, tokenizer, prompt, length, num_samples, temperature):
-     input_ids = tokenizer.encode(prompt, return_tensors='pt')
-     output_sequences = model.generate(
-         input_ids=input_ids,
-         max_length=length,
-         num_return_sequences=num_samples,
-         temperature=temperature
-     )
-
-     generated_texts = []
-     for output_sequence in output_sequences:
-         generated_text = tokenizer.decode(output_sequence, clean_up_tokenization_spaces=True)
-         generated_texts.append(generated_text)
-
-     return generated_texts
-
-
- if st.button('Сгенерировать текст'):
-     generated_texts = generate_text(model, tokenizer, prompt, length, num_samples, temperature)
-     for i, text in enumerate(generated_texts):
-         st.write(f'Текст {i+1}:')
-         st.write(text)
-
-
-
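Note on the file removed above: its generate_text() helper passes temperature and num_return_sequences to model.generate() without do_sample=True, so stock transformers falls back to greedy decoding, the temperature slider is not applied, and requesting more than one sequence raises an error. Below is a minimal sketch of a sampling-based call for the same model; the prompt text, max_length=64, and num_return_sequences=2 are illustrative values, not taken from the commit.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')

input_ids = tokenizer.encode('Пример текста', return_tensors='pt')
with torch.no_grad():
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=64,                         # illustrative; the app read this from a slider
        do_sample=True,                        # required for temperature to take effect
        temperature=1.0,
        num_return_sequences=2,                # >1 only works with sampling or beam search
        pad_token_id=tokenizer.eos_token_id,   # GPT-2 has no pad token by default
    )
texts = [tokenizer.decode(seq, skip_special_tokens=True) for seq in output_sequences]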