yirmibesogluz committed on
Commit
a2b1396
1 Parent(s): 6a67c51

Added config params to Turna

Files changed (1)
  1. apps/home.py +52 -7
apps/home.py CHANGED
@@ -5,8 +5,6 @@ from transformers import pipeline
 import os
 from .utils import query
 
-API_URL = "https://api-inference.huggingface.co/models/boun-tabi-LMG/TURNA"
-
 def write():
     st.markdown(
         """
@@ -35,9 +33,56 @@ def write():
 
     #st.title('Turkish Language Generation')
     #st.write('...with Turna')
-    input_text = st.text_area(label='Enter a text: ', height=100,
+
+    # Sidebar
+
+    # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
+    st.sidebar.subheader("Configurable parameters")
+
+    max_new_tokens = st.sidebar.number_input(
+        "Maximum length",
+        min_value=0,
+        max_value=512,
+        value=128,
+        help="The maximum length of the sequence to be generated.",
+    )
+    length_penalty = st.sidebar.number_input(
+        "Length penalty",
+        value=1.0,
+        help=" length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. ",
+    )
+    do_sample = st.sidebar.selectbox(
+        "Sampling?",
+        (True, False),
+        help="Whether or not to use sampling; use greedy decoding otherwise.",
+    )
+    num_beams = st.sidebar.number_input(
+        "Number of beams",
+        min_value=1,
+        max_value=10,
+        value=3,
+        help="The number of beams to use for beam search.",
+    )
+    repetition_penalty = st.sidebar.number_input(
+        "Repetition Penalty",
+        min_value=0.0,
+        value=3.0,
+        step=0.1,
+        help="The parameter for repetition penalty. 1.0 means no penalty",
+    )
+    no_repeat_ngram_size = st.sidebar.number_input(
+        "No Repeat N-Gram Size",
+        min_value=0,
+        value=3,
+        help="If set to int > 0, all ngrams of that size can only occur once.",
+    )
+
+    input_text = st.text_area(label='Enter a text: ', height=100,
                               value="Türkiye'nin başkenti neresidir?")
-    if st.button("Generate"):
-        with st.spinner('Generating...'):
-            output = query(input_text, API_URL)
-            st.success(output)
+    url = "https://api-inference.huggingface.co/models/boun-tabi-LMG/TURNA"
+    params = {"length_penalty": length_penalty, "no_repeat_ngram_size": no_repeat_ngram_size, "max_new_tokens": max_new_tokens,
+              "do_sample": do_sample, "num_beams": num_beams, "repetition_penalty": repetition_penalty}
+    if st.button("Generate"):
+        with st.spinner('Generating...'):
+            output = query(input_text, url, params)
+            st.success(output)
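
The commit changes the generation call from query(input_text, API_URL) to query(input_text, url, params), so the query helper imported from .utils (not part of this diff) must accept a parameters dict. Below is a minimal sketch of what that helper plausibly looks like, assuming the standard Hugging Face Inference API payload format ({"inputs": ..., "parameters": {...}}) and an API token read from an environment variable; only the call signature comes from the diff, the rest is an assumption.

# Hypothetical apps/utils.py helper, not shown in this commit.
import os
import requests

def query(text, url, params):
    # Assumption: the Space reads its Inference API token from an env var.
    headers = {"Authorization": f"Bearer {os.environ.get('HF_API_TOKEN', '')}"}
    # Inference API payload: the input text plus the generation parameters
    # (max_new_tokens, num_beams, do_sample, ...) built in home.py.
    payload = {"inputs": text, "parameters": params}
    response = requests.post(url, headers=headers, json=payload)
    return response.json()

Note that do_sample comes out of st.sidebar.selectbox as a Python bool, so it serializes cleanly into the JSON payload alongside the numeric sidebar values.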