ColeGuion commited on
Commit
2b4398b
1 Parent(s): 960abb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -8
app.py CHANGED
@@ -1,11 +1,56 @@
 
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- def correct_text(text):
5
- model_id = "grammarly/coedit-large"
6
- pipe = pipeline("text2text-generation", model=model_id)
7
- corrected = pipe(text)[0]['generated_text']
8
- return corrected
9
 
10
- interface = gr.Interface(fn=correct_text, inputs="text_area", outputs="text")
11
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
  import gradio as gr
 
3
 
4
+ client = InferenceClient("grammarly/coedit-large")
 
 
 
 
5
 
6
+ def generate(prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
7
+ temperature = float(temperature)
8
+ if temperature < 1e-2: temperature = 1e-2
9
+ top_p = float(top_p)
10
+
11
+ generate_kwargs = dict(temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty, do_sample=True, seed=42,)
12
+
13
+ #formatted_prompt = format_prompt_grammar(f"Corrected Sentence: {prompt}", history)
14
+ #formatted_prompt = format_prompt(f"{system_prompt} {prompt}", history)
15
+ formatted_prompt = "Fix grammatical errors in this sentence: " + prompt
16
+ print("\nPROMPT: \n\t" + formatted_prompt)
17
+
18
+ # Generate text from the HF inference
19
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
20
+ output = ""
21
+
22
+ for response in stream:
23
+ output += response.token.text
24
+ yield output
25
+ return output
26
+
27
+
28
+
29
+ additional_inputs=[
30
+ gr.Slider( label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs", ),
31
+ gr.Slider( label="Max new tokens", value=256, minimum=0, maximum=1048, step=64, interactive=True, info="The maximum numbers of new tokens", ),
32
+ gr.Slider( label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens", ),
33
+ gr.Slider( label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens", )
34
+ ]
35
+
36
+
37
+ gr.ChatInterface(
38
+ fn=generate,
39
+ chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
40
+ additional_inputs=additional_inputs,
41
+ title="My Grammarly Space",
42
+ concurrency_limit=20,
43
+ ).launch(show_api=False)
44
+
45
+
46
+ #import gradio as gr
47
+ #from transformers import pipeline
48
+
49
+ #def correct_text(text):
50
+ # model_id = "grammarly/coedit-large"
51
+ # pipe = pipeline("text2text-generation", model=model_id)
52
+ # corrected = pipe(text)[0]['generated_text']
53
+ # return corrected
54
+
55
+ #interface = gr.Interface(fn=correct_text, inputs="text_area", outputs="text")
56
+ #interface.launch()