ColeGuion committed
Commit 9bd7774 · verified · 1 Parent(s): 9797188

Update app.py

Files changed (1):
  1. app.py +39 -16
app.py CHANGED
@@ -1,17 +1,40 @@
-import gradio as gr
-from transformers import pipeline
-
-def correct_text(text, temperature=1.0, top_p=0.9, top_k=50):
-    model_id = "grammarly/coedit-large"
-    pipe = pipeline("text2text-generation", model=model_id,
-                    generation_kwargs={"temperature": temperature, "top_p": top_p, "top_k": top_k})
-    corrected = pipe(text)[0]['generated_text']
-    return corrected
-
-interface = gr.Interface(fn=correct_text,
-                         inputs=[gr.inputs.TextArea(label="Input Text"),
-                                 gr.inputs.Slider(minimum=0.1, maximum=1.0, default=1.0, label="Temperature"),
-                                 gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.9, label="Top_p"),
-                                 gr.inputs.Slider(minimum=0, maximum=100, default=50, label="Top_k")],
-                         outputs="text")
-interface.launch()
+from huggingface_hub import InferenceClient
+import gradio as gr
+
+client = InferenceClient("grammarly/coedit-large")
+
+def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k=50):
+    # ChatInterface calls fn(message, history, *additional_inputs),
+    # so the signature must accept history and one argument per slider.
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    generate_kwargs = dict(temperature=temperature, max_new_tokens=int(max_new_tokens), top_p=top_p, top_k=int(top_k))  # seed=42
+
+    formatted_prompt = "Fix grammatical errors in this sentence: " + prompt
+    print("\nPROMPT: \n\t" + formatted_prompt)
+
+    # Stream the generated text from the HF Inference API
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+    for response in stream:
+        output += response.token.text
+        yield output
+
+additional_inputs = [
+    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
+    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1048, step=64, interactive=True, info="The maximum number of new tokens"),
+    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
+    gr.Slider(label="Top-k", value=50, minimum=0, maximum=100, step=1, interactive=True, info="Number of highest-probability tokens considered at each step"),
+]
+
+gr.ChatInterface(
+    fn=generate,
+    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
+    additional_inputs=additional_inputs,
+    title="My Grammarly Space",
+    concurrency_limit=20,
+).launch(show_api=False)
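
The streaming call can be exercised outside Gradio for local testing. The sketch below is an assumption-laden example, not part of the commit: it assumes huggingface_hub is installed and that the Inference API serves grammarly/coedit-large under the text-generation task (coedit-large is a T5-style text2text model, so a serverless deployment may reject this task and require a dedicated endpoint). The prompt string is illustrative only.

    # Hedged sketch: stream tokens from the Inference API outside Gradio.
    # Assumes grammarly/coedit-large is reachable via the text-generation task.
    from huggingface_hub import InferenceClient

    client = InferenceClient("grammarly/coedit-large")
    prompt = "Fix grammatical errors in this sentence: She go to school every days."

    stream = client.text_generation(
        prompt,
        max_new_tokens=64,
        temperature=0.9,
        top_p=0.95,
        top_k=50,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for response in stream:
        output += response.token.text  # one generated token per stream event
    print(output)

With stream=True and details=True, each event carries the newly generated token in response.token.text, which is why the app accumulates tokens into output and yields the running string for the chat UI.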