zurd46 committed
Commit 34c466e · verified · 1 Parent(s): 36586bf

Create app.py

Files changed (1):
  app.py +172 -0
app.py ADDED
@@ -0,0 +1,172 @@
# Import necessary libraries
from threading import Thread
import argparse
import torch
import gradio as gr
from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
from peft import PeftModel
from utils import get_device  # Assumes this helper already exists (see the note at the end)

# Create the parser
parser = argparse.ArgumentParser(description='Check model usage.')

# Add the arguments
parser.add_argument('--baseonly', action='store_true',
                    help='A boolean switch to indicate base-only mode')

# Execute the parse_args() method
args = parser.parse_args()
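
# Example invocation (assuming this file is saved as app.py):
#   python app.py             -> base model plus the eliAI adapter
#   python app.py --baseonly  -> base model only
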
# Define model and adapter names and the torch data type
model_name = "microsoft/Phi-3-mini-4k-instruct"
adapters_name = "zurd46/eliAI"
torch_dtype = torch.bfloat16  # Set the appropriate torch data type

# Display device and CPU thread information
device = get_device()
print("Running on device:", device)
print("CPU threads:", torch.get_num_threads())

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load base model
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype)
model.resize_token_embeddings(len(tokenizer))
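
# Note: resize_token_embeddings only changes anything if the tokenizer's
# vocabulary size differs from the checkpoint's embedding size; with the
# stock Phi-3 tokenizer this is effectively a no-op.
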
# Load the adapter unless base-only mode was requested
usingAdapter = False
if not args.baseonly:
    usingAdapter = True
    model = PeftModel.from_pretrained(model, adapters_name)

model.to(device)

print(f"Model {model_name} loaded successfully on {device}")
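
# Note (optional variant): for LoRA-style adapters, PEFT's merge_and_unload()
# could be called here to fold the adapter weights into the base model for
# slightly faster inference, at the cost of no longer being able to detach it.
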
# Function to run the text generation process
def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
    # Prompt template the adapter presumably saw during fine-tuning
    template = "\n{}\n"
    model_inputs = tokenizer(template.format(user_text) if usingAdapter else user_text, return_tensors="pt")
    model_inputs = model_inputs.to(device)

    # Generate text in a separate thread so the streamer can be consumed below
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=model_inputs['input_ids'],
        attention_mask=model_inputs['attention_mask'],
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=float(temperature),
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Collect the streamed chunks and return the full generated text
    model_output = ""
    for new_text in streamer:
        model_output += new_text
    return model_output
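
# Note (possible variant): since Gradio event handlers may be generators,
# yielding model_output inside the loop above would stream partial output to
# the Chatbot incrementally instead of returning the full text at the end.
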
# Gradio UI setup
with gr.Blocks(css="""
    div.svelte-sfqy0y {
        display: flex;
        flex-direction: inherit;
        flex-wrap: wrap;
        gap: var(--form-gap-width);
        box-shadow: var(--block-shadow);
        border: var(--block-border-width) solid var(--border-color-primary);
        border-radius: var(--block-radius);
        background: var(--block-background-fill);
        overflow-y: hidden;
        padding: 20px;
    }
    body {
        font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
        background-color: var(--body-background-fill);
        color: #e0e0e0;
        margin: 0;
        padding: 0;
        box-sizing: border-box;
    }
    .gradio-container {
        max-width: 900px;
        margin: auto;
        padding: 20px;
        border-radius: 8px;
        box-shadow: 0 0 10px rgba(0,0,0,0.5);
        background: var(--body-background-fill);
    }
    .gr-button {
        background-color: var(--block-background-fill);
        color: white;
        border: none;
        border-radius: 4px;
        padding: 10px 24px;
        cursor: pointer;
    }
    .gr-button:hover {
        background-color: #3700b3;
    }
    .gr-slider input[type=range] {
        -webkit-appearance: none;
        width: 100%;
        height: 8px;
        border-radius: 5px;
        background: #333;
        outline: none;
        opacity: 0.9;
        -webkit-transition: .2s;
        transition: opacity .2s;
    }
    .gr-slider input[type=range]:hover {
        opacity: 1;
    }
    .gr-textbox {
        background-color: var(--block-background-fill);
        color: white;
        border: none;
        border-radius: 4px;
        padding: 10px;
    }
    .chatbox {
        max-height: 400px;
        overflow-y: auto;
        margin-bottom: 20px;
    }
    """) as demo:
    gr.Markdown(
        """
        <div style="text-align: center; padding: 20px;">
            <h1>🌙 eliAI Text Generation Interface</h1>
            <h3>Model: Phi-3-mini-4k-instruct</h3>
            <h4>Developed by Daniel Zurmühle</h4>
        </div>
        """)

    with gr.Row():
        with gr.Column(scale=3):
            user_text = gr.Textbox(placeholder="Enter your question here", label="User Input", lines=3, elem_classes="gr-textbox")
            button_submit = gr.Button(value="Submit", elem_classes="gr-button")

            max_new_tokens = gr.Slider(minimum=1, maximum=1000, value=1000, step=1, label="Max New Tokens")
            top_p = gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.05, label="Top-p (Nucleus Sampling)")
            top_k = gr.Slider(minimum=1, maximum=50, value=50, step=1, label="Top-k")
            temperature = gr.Slider(minimum=0.1, maximum=5.0, value=0.8, step=0.1, label="Temperature")

        with gr.Column(scale=7):
            model_output = gr.Chatbot(label="Chatbot Output", height=566)

    # Run one generation and hand the Chatbot a single (user, bot) message pair
    def handle_submit(text, top_p, temperature, top_k, max_new_tokens):
        response = run_generation(text, top_p, temperature, top_k, max_new_tokens)
        return [(text, response)]
    button_submit.click(handle_submit, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
    user_text.submit(handle_submit, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)

# Queue incoming requests (up to 32 waiting) and serve on all interfaces
demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860)
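
Note: utils.py is not included in this commit, so the get_device helper imported
at the top of app.py is assumed to exist. A minimal sketch of what it presumably
does (the committed implementation may differ):

import torch

def get_device():
    # Prefer CUDA, then Apple Silicon (MPS), then fall back to the CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")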