Ilvir commited on
Commit
e57f431
·
1 Parent(s): 9a261af

Update pages/gpt (1).py

Browse files
Files changed (1) hide show
  1. pages/gpt (1).py +8 -71
pages/gpt (1).py CHANGED
@@ -1,73 +1,10 @@
1
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
  import streamlit as st
3
- import torch
4
- import textwrap
5
- import plotly.express as px
6
-
7
- from streamlit_extras.let_it_rain import rain
8
-
9
- rain(
10
- emoji="⭐",
11
- font_size=54,
12
- falling_speed=5,
13
- animation_length="infinite",
14
- )
15
-
16
- st.header(':green[Text generation by GPT2 model]')
17
-
18
- tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
19
- model = GPT2LMHeadModel.from_pretrained(
20
- 'sberbank-ai/rugpt3small_based_on_gpt2',
21
- output_attentions = False,
22
- output_hidden_states = False,
23
- )
24
-
25
- model.load_state_dict(torch.load('models/model.pt', map_location=torch.device('cpu')))
26
-
27
-
28
- length = st.sidebar.slider('**Generated sequence length:**', 8, 256, 15)
29
- if length > 100:
30
- st.warning("This is very hard for me, please have pity on me. Could you lower the value?", icon="🤖")
31
- num_samples = st.sidebar.slider('**Number of generations:**', 1, 10, 1)
32
- if num_samples > 4:
33
- st.warning("OH MY ..., I have to work late again!!! Could you lower the value?", icon="🤖")
34
- temperature = st.sidebar.slider('**Temperature:**', 0.1, 10.0, 3.0)
35
- if temperature > 6.0:
36
- st.info('What? You want to get some kind of bullshit as a result? Turn down the temperature', icon="🤖")
37
- top_k = st.sidebar.slider('**Number of most likely generation words:**', 10, 200, 50)
38
- top_p = st.sidebar.slider('**Minimum total probability of top words:**', 0.4, 1.0, 0.9)
39
-
40
-
41
- prompt = st.text_input('**Enter text 👇:**')
42
- if st.button('**Generate text**'):
43
- image_container = st.empty()
44
- image_container.image("pict/wait.jpeg", caption="that's so long!!!", use_column_width=True)
45
- with torch.inference_mode():
46
- prompt = tokenizer.encode(prompt, return_tensors='pt')
47
- out = model.generate(
48
- input_ids=prompt,
49
- max_length=length,
50
- num_beams=8,
51
- do_sample=True,
52
- temperature=temperature,
53
- top_k=top_k,
54
- top_p=top_p,
55
- no_repeat_ngram_size=3,
56
- num_return_sequences=num_samples,
57
- ).cpu().numpy()
58
- image_container.empty()
59
- st.write('**_Результат_** 👇')
60
- for i, out_ in enumerate(out):
61
- # audio_file = open('pict/pole-chudes-priz.mp3', 'rb')
62
- # audio_bytes = audio_file.read()
63
- # st.audio(audio_bytes, format='audio/mp3')
64
-
65
- with st.expander(f'Текст {i+1}:'):
66
- st.write(textwrap.fill(tokenizer.decode(out_), 100))
67
- st.image("pict/wow.png")
68
-
69
-
70
-
71
-
72
-
73
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ txt = st.text_area('Text to analyze', '''
4
+ It was the best of times, it was the worst of times, it was
5
+ the age of wisdom, it was the age of foolishness, it was
6
+ the epoch of belief, it was the epoch of incredulity, it
7
+ was the season of Light, it was the season of Darkness, it
8
+ was the spring of hope, it was the winter of despair, (...)
9
+ ''')
10
+ st.write('Sentiment:', run_sentiment_analysis(txt))