younesbelkada committed
Commit
5c7de89
1 Parent(s): 91bc5f0
Files changed (2)
  1. app.py +58 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,58 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b3", use_cache=True)
+ tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b3")
+
+ def post_process_sentence(input_sentence, generated_sentence):
+     # Keep only the first generated line beyond the prompt and append "\n- "
+     # so the running transcript stays formatted as a bullet list.
+     new_sentence = generated_sentence.replace(input_sentence, "")
+     if "\n" not in new_sentence:
+         return generated_sentence.replace("  ", " ") + "\n- "
+     else:
+         return (input_sentence + new_sentence.split("\n")[0]).replace("  ", " ") + "\n- "
+
+ def generate_single(model, tokenizer, input_sentence, max_length=50, top_k=0, temperature=0.7):
+     input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
+     output = model.generate(
+         input_ids,
+         do_sample=True,
+         # budget the generation in tokens, not characters
+         max_length=input_ids.shape[1] + max_length,
+         top_k=top_k,
+         temperature=temperature,
+     )
+     generated_sentence = tokenizer.decode(output[0], skip_special_tokens=True)
+     return post_process_sentence(input_sentence, generated_sentence)
+
+ def question_bloom(input_sentence, max_length, temperature):
+     post_processed_output = generate_single(
+         model, tokenizer, input_sentence,
+         temperature=temperature, max_length=int(max_length),
+     )
+     # post_process_sentence always appends "\n- ", so the newest answer is
+     # the second-to-last "\n-"-separated chunk.
+     return post_processed_output.split("\n-")[-2]
+
+ gr.Interface(
+     question_bloom,
+     [
+         gr.Textbox(lines=10, label="Input code"),
+         gr.Slider(
+             minimum=8,
+             maximum=256,
+             step=1,
+             value=8,
+             label="Number of tokens to generate",
+         ),
+         gr.Slider(
+             minimum=0,
+             maximum=2,
+             step=0.1,
+             value=0.6,
+             label="Temperature",
+         ),
+     ],
+     outputs=gr.Textbox(label="Predicted sentence", lines=10),
+ ).launch()
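A quick sanity check of the post-processing, outside the commit itself: with a made-up prompt/completion pair (the strings below are hypothetical, not from the Space), post_process_sentence trims the raw completion to a single bullet, and question_bloom's split("\n-")[-2] then picks that bullet out:

    prompt = "- What is the capital of France?\n- Paris\n- What is the capital of Italy?\n-"
    raw = prompt + " Rome\n- What is the capital of Spain?"  # pretend model output
    trimmed = post_process_sentence(prompt, raw)             # "...Italy?\n- Rome\n- "
    print(trimmed.split("\n-")[-2])                          # -> " Rome"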
requirements.txt ADDED
@@ -0,0 +1 @@
+ transformers