nan-dre committed on
Commit
f7d62af
•
1 Parent(s): 186b882

Initial commit

Files changed (1)
app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
+ import streamlit as st
+ import torch
+ from time import perf_counter
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ MODEL = 'nan-dre/maneleGPT-medium'
+ TOKENIZER = 'nan-dre/maneleGPT-medium'
+ MAX_LENGTH = 256
+
+ st.set_page_config(
+     page_title="ManeleGPT",
+     page_icon="🇷🇴",
+     layout="centered"
+ )
+
+ def typical_sampling(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, temperature, typical_p):
+     return model.generate(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         no_repeat_ngram_size=no_repeat_ngram_size,
+         max_length=max_length,
+         do_sample=True,
+         temperature=temperature,
+         typical_p=typical_p,
+         top_k=0  # disable top-k so only temperature/typical_p shape the distribution
+     )
+
+
+ @st.cache_resource  # load the model once and reuse it across Streamlit reruns
+ def load_model():
+     model = AutoModelForCausalLM.from_pretrained(MODEL)
+     tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
+     return model, tokenizer
+
+ st.header("ManeleGPT")
+ temperature = st.slider(label="Temperatura", min_value=0.01, max_value=1.0, value=0.5, step=0.01)
+ seed_text = st.text_input(label="Cu ce vers sa inceapa maneaua?", value="", key="seed")  # "What verse should the manea start with?"
+
+ if seed_text:
+     model, tokenizer = load_model()
+
+     tokenized_text = tokenizer(seed_text, add_special_tokens=False, return_tensors="pt")
+
+     if len(tokenized_text.input_ids[0]) + MAX_LENGTH > 512:  # prompt + generation would exceed the 512-token context
+         keep_last = 512 - MAX_LENGTH
+         print(f"keep last: {keep_last}")
+         input_ids, attention_mask = tokenized_text.input_ids[0][-keep_last:], tokenized_text.attention_mask[0][-keep_last:]
+         previous_ids = tokenized_text.input_ids[0][:-keep_last]  # tokens trimmed from the front, re-attached after generation
+         st.warning(f"Prompt truncated to the last {keep_last} tokens")
+     else:
+         input_ids, attention_mask = tokenized_text.input_ids[0], tokenized_text.attention_mask[0]
+         previous_ids = None
+
+     length = min(512, len(input_ids) + MAX_LENGTH)  # max_length counts the prompt, so allow prompt + MAX_LENGTH new tokens, capped at 512
+     timer_mark = perf_counter()
+     output = typical_sampling(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngram_size=2, max_length=length, temperature=temperature, typical_p=1)  # typical_p=1 leaves typical filtering off, i.e. plain temperature sampling
+     details = f"Text generated in {perf_counter()-timer_mark:.2f}s"
+
+     if previous_ids is not None:
+         print("\nConcat prev id: " + tokenizer.decode(previous_ids, skip_special_tokens=True))
+         print("\nWith current decode: " + tokenizer.decode(output[0], skip_special_tokens=True))
+         new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=-1), skip_special_tokens=True)
+     else:
+         new_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+     st.text(new_text)
+     st.caption(details)  # surface the generation time in the UI
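
For quick local testing, the usual Streamlit workflow applies: install streamlit, torch, and transformers, then run streamlit run app.py from the repository root. The generation call can also be exercised outside Streamlit. The sketch below is a minimal example against the same checkpoint, not part of this commit: the seed verse is an arbitrary placeholder, and typical_p is set below 1.0 so that locally typical sampling is actually active (the app itself passes typical_p=1, which leaves the filter off and reduces to plain temperature sampling).

    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained('nan-dre/maneleGPT-medium')
    model = AutoModelForCausalLM.from_pretrained('nan-dre/maneleGPT-medium')

    # Arbitrary placeholder seed verse, not taken from the commit.
    inputs = tokenizer("Am o inima de aur", return_tensors="pt")
    output = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.5,
        typical_p=0.9,           # below 1.0, so typical filtering is enabled
        top_k=0,                 # no top-k cutoff, matching the app
        no_repeat_ngram_size=2,
        max_length=64,
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))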