Locutusque commited on
Commit
91ae465
1 Parent(s): 4f715a1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import torch
4
+ import subprocess
5
+ import spaces
6
+
7
+
8
+ @spaces.GPU
9
+ def _build_flash_attn():
10
+ subprocess.check_call("pip install flash-attn", shell=True)
11
+ _build_flash_attn() # This is how we'll build flash-attn.
12
+ # Initialize the model pipeline
13
+ generator = pipeline('text-generation', model='mistralai/Mistral-7B-v0.1', torch_dtype=torch.bfloat16, use_flash_attention_2=True)
14
+ @spaces.GPU
15
+ def generate_text(prompt, temperature, top_p, top_k, repetition_penalty, max_length):
16
+ # Generate text using the model
17
+ generator.model.cuda()
18
+ outputs = generator(
19
+ prompt,
20
+ max_new_tokens=max_length,
21
+ temperature=temperature,
22
+ top_p=top_p,
23
+ top_k=top_k,
24
+ repetition_penalty=repetition_penalty,
25
+ return_full_text=False
26
+ )
27
+ # Extract the generated text and return it
28
+ generated_text = outputs[0]['generated_text']
29
+ return generated_text
30
+ # Create the Gradio interface
31
+ iface = gr.Interface(
32
+ fn=generate_text,
33
+ inputs=[
34
+ gr.inputs.Textbox(label="Prompt", lines=2, placeholder="Type a prompt..."),
35
+ gr.inputs.Slider(minimum=0.1, maximum=2.0, step=0.01, default=0.8, label="Temperature"),
36
+ gr.inputs.Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.95, label="Top p"),
37
+ gr.inputs.Slider(minimum=0, maximum=100, step=1, default=40, label="Top k"),
38
+ gr.inputs.Slider(minimum=1.0, maximum=2.0, step=0.01, default=1.10, label="Repetition Penalty"),
39
+ gr.inputs.Slider(minimum=5, maximum=4096, step=5, default=1024, label="Max Length")
40
+ ],
41
+ outputs=gr.outputs.Textbox(label="Generated Text"),
42
+ title="Text Completion Model",
43
+ description="Try out the Mistral-7B model for free! Note this is the pretrained model and is not fine-tuned for instruction."
44
+ )
45
+
46
+ iface.launch()