Staticaliza committed on
Commit
95167e7
1 Parent(s): e9bf215

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import Repository, InferenceClient
3
+ import os
4
+ import json
5
+
6
# Secrets and configuration supplied by the Space environment.
API_TOKEN = os.environ.get("API_TOKEN")
API_ENDPOINT = os.environ.get("API_ENDPOINT")

# Shared access key required by predict() before any generation runs.
KEY = os.environ.get("KEY")

# Display name -> Hugging Face Hub repo id.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf"
}

# Dropdown choices mirror the endpoint map; one InferenceClient per model,
# all authenticated with the same bearer token.
CHOICES = list(API_ENDPOINTS)
CLIENTS = {
    name: InferenceClient(endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })
    for name, endpoint in API_ENDPOINTS.items()
}
22
+
23
def predict(input, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Generate text with the selected model.

    Returns a (response, input) tuple — the generated text (or an error
    marker) plus the original prompt — matching the two Gradio outputs
    wired to this handler.
    """
    # Gate every request behind the shared access key.
    if access_key != KEY:
        print(">>> MODEL FAILED: Input: " + input + ", Attempted Key: " + access_key)
        return ("[UNAUTHORIZED ACCESS]", input)

    # Fix: the original called json.loads unguarded, so malformed user JSON
    # crashed the request; fail soft with an error marker instead.
    try:
        stops = json.loads(stop_seqs)
    except json.JSONDecodeError:
        return ("[ERROR] Stop Sequences must be a valid JSON array.", input)
    if not isinstance(stops, list):
        return ("[ERROR] Stop Sequences must be a valid JSON array.", input)
    # The UI advertises "4 Max" stop sequences; enforce it here so the
    # inference API never receives more than it accepts.
    stops = stops[:4]

    response = CLIENTS[model].text_generation(
        input,
        temperature = temperature,
        max_new_tokens = max_tokens,
        top_p = top_p,
        top_k = top_k,
        repetition_penalty = rep_p,
        stop_sequences = stops,
        do_sample = True,
        seed = seed,
        stream = False,
        details = False,
        return_full_text = False
    )

    print(f"---\nUSER: {input}\nBOT: {response}\n---")

    return (response, input)
49
+
50
def maintain_cloud():
    """Keep-alive hook: log that the Space was pinged and report success twice."""
    print(">>> SPACE MAINTAINED!")
    return "SUCCESS!", "SUCCESS!"
53
+
54
# Gradio UI: prompt + access key on the left, sampling controls on the
# right, a single output box below, wired to predict() and maintain_cloud().
with gr.Blocks() as demo:
    with gr.Row(variant = "panel"):
        gr.Markdown("⚛️ This is a private LLM Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")

    with gr.Row():
        with gr.Column():
            input = gr.Textbox(label = "Input", lines = 4)
            access_key = gr.Textbox(label = "Access Key", lines = 1)
            run = gr.Button("▶")
            cloud = gr.Button("☁️")

        with gr.Column():
            model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
            temperature = gr.Slider( minimum = 0, maximum = 2, value = 1, step = 0.01, interactive = True, label = "Temperature" )
            top_p = gr.Slider( minimum = 0.01, maximum = 0.99, value = 0.95, step = 0.01, interactive = True, label = "Top P" )
            top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
            rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
            # Fix: step was 64 with minimum = 1 and value = 32, so the default
            # value was not on the slider's step grid (1, 65, 129, ...); a step
            # of 1 keeps the default reachable and every token count selectable.
            max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 1, interactive = True, label = "Max New Tokens" )
            stop_seqs = gr.Textbox(label = "Stop Sequences ( JSON Array / 4 Max )", lines = 1, value = '["‹", "›"]')
            seed = gr.Slider( minimum = 0, maximum = 8192, value = 42, step = 1, interactive = True, label = "Seed" )

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label = "Output", value = "", lines = 50)

    # predict() echoes the prompt back into the input box; maintain_cloud()
    # overwrites both boxes with its success markers.
    run.click(predict, inputs = [input, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input])
    cloud.click(maintain_cloud, inputs = [], outputs = [input, output])

demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)