MorenoLaQuatra committed on
Commit
7c7422b
1 Parent(s): e6612b7

Added parameters for a better space

Browse files
Files changed (1) hide show
  1. app.py +29 -13
app.py CHANGED
@@ -2,15 +2,15 @@ import gradio as gr
2
  from transformers import pipeline
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
- #tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-1.3b")
6
- #model = AutoModelForCausalLM.from_pretrained("facebook/galactica-1.3b")
7
- #text2text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, num_workers=2)
8
 
9
- def predict(text):
10
  text = text.strip()
11
- out_text = text2text_generator(text, max_length=128,
12
- temperature=0.7,
13
- do_sample=True,
14
  eos_token_id = tokenizer.eos_token_id,
15
  bos_token_id = tokenizer.bos_token_id,
16
  pad_token_id = tokenizer.pad_token_id,
@@ -22,12 +22,28 @@ def predict(text):
22
  return out_text
23
 
24
  iface = gr.Interface.load(
25
- "huggingface/facebook/galactica-1.3b",
26
- #fn=predict,
27
- #inputs=gr.Textbox(lines=10),
28
- #outputs=gr.HTML(),
29
- description="Galactica",
30
- examples=[["The attention mechanism in LLM is"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  )
32
 
33
  iface.launch()
2
  from transformers import pipeline
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
# Load the Galactica 1.3B checkpoint once at module import so every
# Gradio request reuses the same tokenizer/model pair instead of
# re-downloading or re-initializing per call.
tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/galactica-1.3b")
# text-generation pipeline wrapping the model above; num_workers=2
# parallelizes the pipeline's pre/post-processing workers.
text2text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, num_workers=2)
8
 
9
+ def predict(text, max_length=64, temperature=0.7, do_sample=True):
10
  text = text.strip()
11
+ out_text = text2text_generator(text, max_length=max_length,
12
+ temperature=temperature,
13
+ do_sample=do_sample,
14
  eos_token_id = tokenizer.eos_token_id,
15
  bos_token_id = tokenizer.bos_token_id,
16
  pad_token_id = tokenizer.pad_token_id,
22
  return out_text
23
 
24
# Build the demo UI around the local predict() function.
#
# Bug fix: the original used gr.Interface.load(fn=predict, ...).
# Interface.load() exists to mirror a hosted Hub model and does not
# accept fn/inputs/outputs — the custom predict() (and its sliders)
# would never actually be called. The gr.Interface constructor is the
# correct entry point for a local function.
#
# Also unified on the modern top-level component API (gr.Textbox,
# gr.Slider, gr.Checkbox) instead of the deprecated gr.inputs.*
# namespace, which was inconsistently mixed with gr.HTML() here;
# the modern components take value= rather than default=.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(lines=5, label="Input Text"),
        gr.Slider(minimum=32, maximum=256, value=64, label="Max Length"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        # value=True aligns the UI default with predict(do_sample=True)
        # and with both examples below; an unchecked default silently
        # disabled sampling on first use.
        gr.Checkbox(value=True, label="Do Sample"),
    ],
    outputs=gr.HTML(),
    description="Galactica Base Model",
    # Each example supplies all four inputs: text, max_length,
    # temperature, do_sample.
    examples=[
        [
            "The attention mechanism in LLM is",
            128,
            0.7,
            True,
        ],
        [
            "Title: Attention is all you need\n\nAbstract:",
            128,
            0.7,
            True,
        ],
    ],
)

iface.launch()