richardr1126 committed on
Commit
40c895f
0 Parent(s):

Duplicate from richardr1126/natsql-wizardcoder-demo

Browse files
Files changed (5) hide show
  1. .gitattributes +34 -0
  2. .gitignore +4 -0
  3. README.md +13 -0
  4. app.py +92 -0
  5. requirements.txt +10 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ venv/
2
+ .venv/
3
+ env/
4
+ .env/
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: NatSQL WizardCoder Demo
3
+ emoji: 🧙‍♂️
4
+ colorFrom: gray
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.37.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: richardr1126/natsql-wizardcoder-demo
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from threading import Event, Thread
3
+ from transformers import (
4
+ AutoModelForCausalLM,
5
+ AutoTokenizer,
6
+ StoppingCriteria,
7
+ StoppingCriteriaList,
8
+ TextIteratorStreamer,
9
+ )
10
+ from huggingface_hub import login
11
+ import gradio as gr
12
+ import torch
13
+
14
+ login(os.getenv("HF_TOKEN", None))
15
+
16
+ model_name = "richardr1126/spider-natsql-wizard-coder-8bit"
17
+ tok = AutoTokenizer.from_pretrained(model_name)
18
+
19
+ max_new_tokens = 1536
20
+
21
+ print(f"Starting to load the model {model_name}")
22
+
23
+ m = AutoModelForCausalLM.from_pretrained(
24
+ model_name,
25
+ device_map=0,
26
+ load_in_8bit=True,
27
+ )
28
+
29
+ m.config.pad_token_id = m.config.eos_token_id
30
+ m.generation_config.pad_token_id = m.config.eos_token_id
31
+
32
+ stop_tokens = [";", "###", "Result"]
33
+ stop_token_ids = tok.convert_tokens_to_ids(stop_tokens)
34
+
35
+ print(f"Successfully loaded the model {model_name} into memory")
36
+
37
+ class StopOnTokens(StoppingCriteria):
38
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
39
+ for stop_id in stop_token_ids:
40
+ if input_ids[0][-1] == stop_id:
41
+ return True
42
+ return False
43
+
44
+ def bot(input_message: str, temperature=0.1, top_p=0.9, top_k=0, repetition_penalty=1.08):
45
+ stop = StopOnTokens()
46
+
47
+ messages = input_message
48
+
49
+ input_ids = tok(messages, return_tensors="pt").input_ids
50
+ input_ids = input_ids.to(m.device)
51
+ streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
52
+ generate_kwargs = dict(
53
+ input_ids=input_ids,
54
+ max_new_tokens=max_new_tokens,
55
+ temperature=temperature,
56
+ do_sample=temperature > 0.0,
57
+ top_p=top_p,
58
+ top_k=top_k,
59
+ repetition_penalty=repetition_penalty,
60
+ streamer=streamer,
61
+ stopping_criteria=StoppingCriteriaList([stop]),
62
+ )
63
+
64
+ stream_complete = Event()
65
+
66
+ def generate_and_signal_complete():
67
+ m.generate(**generate_kwargs)
68
+ stream_complete.set()
69
+
70
+ t1 = Thread(target=generate_and_signal_complete)
71
+ t1.start()
72
+
73
+ partial_text = ""
74
+ for new_text in streamer:
75
+ partial_text += new_text
76
+
77
+ return partial_text
78
+
79
+ gradio_interface = gr.Interface(
80
+ fn=bot,
81
+ inputs=[
82
+ "text",
83
+ gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.1, step=0.1),
84
+ gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01),
85
+ gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1),
86
+ gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.1)
87
+ ],
88
+ outputs="text",
89
+ title="REST API with Gradio and Huggingface Spaces",
90
+ description="This is a demo of how to build an AI powered REST API with Gradio and Huggingface Spaces – for free! See the **Use via API** link at the bottom of this page.",
91
+ )
92
+ gradio_interface.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ einops
2
+ gradio
3
+ torch
4
+ numpy
5
+ sentencepiece
6
+ bitsandbytes
7
+ scipy
8
+ transformers
9
+ accelerate
10
+ huggingface_hub