Leri777 committed on
Commit 900c5e2 · verified · 1 Parent(s): dd7ba44

Update app.py

Files changed (1)
  1. app.py +147 -109
app.py CHANGED
@@ -1,116 +1,154 @@
  import os
- import logging
  import gradio as gr
- from huggingface_hub import InferenceClient
- from logging.handlers import RotatingFileHandler
-
- # Logging setup
- log_file = 'app_debug.log'
- logger = logging.getLogger(__name__)
- logger.setLevel(logging.DEBUG)
- file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
- file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
- logger.addHandler(file_handler)
- logger.debug("Application started")
-
- # Inference client setup
- client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
-
- def format_prompt(message, history):
-     prompt = "<s>"
-     for user_prompt, bot_response in history:
-         prompt += f"[INST] {user_prompt} [/INST]"
-         prompt += f" {bot_response}</s> "
-     prompt += f"[INST] {message} [/INST]"
-     return prompt
-
- def generate(
-     prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
  ):
-     logger.debug(f"Generating response for prompt: {prompt} with history: {history}")
-     temperature = float(temperature)
-     if temperature < 1e-2:
-         temperature = 1e-2
-     top_p = float(top_p)

      generate_kwargs = dict(
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=repetition_penalty,
-         do_sample=True,
-         seed=42,
      )

-     formatted_prompt = format_prompt(prompt, history)
-
-     try:
-         stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-         output = ""
-
-         for response in stream:
-             output += response.token.text
-             yield output
-         logger.debug(f"Generated response: {output}")
-         return output
-     except Exception as e:
-         logger.exception("Error during text generation")
-         return "An error occurred during response generation."
-
- def update_history(history, user_input, response):
-     history.append((user_input, response))
-     return history
-
- def chat_interface(user_input, history=[], temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
-     logger.debug(f"User input: {user_input}")
-     response = generate(user_input, history, temperature, max_new_tokens, top_p, repetition_penalty)
-     history = update_history(history, user_input, response)
-     logger.debug(f"Updated history: {history}")
-     return response, history
-
- gr.ChatInterface(
-     fn=chat_interface,
-     chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=False, likeable=False, layout="panel"),
-     additional_inputs=[
-         gr.Slider(
-             label="Temperature",
-             value=0.9,
-             minimum=0.0,
-             maximum=1.0,
-             step=0.05,
-             interactive=True,
-             info="Higher values produce more diverse outputs",
-         ),
-         gr.Slider(
-             label="Max new tokens",
-             value=256,
-             minimum=0,
-             maximum=1048,
-             step=64,
-             interactive=True,
-             info="The maximum numbers of new tokens",
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             value=0.90,
-             minimum=0.0,
-             maximum=1,
-             step=0.05,
-             interactive=True,
-             info="Higher values sample more low-probability tokens",
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             value=1.2,
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             interactive=True,
-             info="Penalize repeated tokens",
-         )
-     ],
-     title="Mistral 7B v0.3",
-     description=None
- ).launch(show_api=True)
-
- logger.debug("Chat interface initialized and launched")
 
  import os
+ import time
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  import gradio as gr
+ from threading import Thread
+
+ MODEL_LIST = ["mistralai/Mistral-Nemo-Instruct-2407"]
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ MODEL = os.environ.get("MODEL_ID")
+
+ TITLE = "<h1><center>Mistral-Nemo</center></h1>"
+
+ PLACEHOLDER = """
+ <center>
+ <p>The Mistral-Nemo is a pretrained generative text model of 12B parameters trained jointly by Mistral AI and NVIDIA.</p>
+ </center>
+ """
+
+
+ CSS = """
+ .duplicate-button {
+     margin: auto !important;
+     color: white !important;
+     background: black !important;
+     border-radius: 100vh !important;
+ }
+ h3 {
+     text-align: center;
+ }
+ """
+
+ device = "cuda"  # for GPU usage or "cpu" for CPU usage
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     ignore_mismatched_sizes=True)
+
+ @spaces.GPU()
+ def stream_chat(
+     message: str,
+     history: list,
+     temperature: float = 0.3,
+     max_new_tokens: int = 1024,
+     top_p: float = 1.0,
+     top_k: int = 20,
+     penalty: float = 1.2,
  ):
+     print(f'message: {message}')
+     print(f'history: {history}')

+     conversation = []
+     for prompt, answer in history:
+         conversation.extend([
+             {"role": "user", "content": prompt},
+             {"role": "assistant", "content": answer},
+         ])
+
+     conversation.append({"role": "user", "content": message})
+
+     input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
+     inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
      generate_kwargs = dict(
+         input_ids=inputs,
+         max_new_tokens=max_new_tokens,
+         do_sample=False if temperature == 0 else True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         streamer=streamer,
+         repetition_penalty=penalty,
+         pad_token_id=10,
      )

+     with torch.no_grad():
+         thread = Thread(target=model.generate, kwargs=generate_kwargs)
+         thread.start()
+
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         yield buffer
+
+
+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
+
+ with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
+     gr.HTML(TITLE)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+     gr.ChatInterface(
+         fn=stream_chat,
+         chatbot=chatbot,
+         fill_height=True,
+         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+         additional_inputs=[
+             gr.Slider(
+                 minimum=0,
+                 maximum=1,
+                 step=0.1,
+                 value=0.3,
+                 label="Temperature",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=128,
+                 maximum=8192,
+                 step=1,
+                 value=1024,
+                 label="Max new tokens",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=1.0,
+                 step=0.1,
+                 value=1.0,
+                 label="top_p",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=1,
+                 maximum=20,
+                 step=1,
+                 value=20,
+                 label="top_k",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=2.0,
+                 step=0.1,
+                 value=1.2,
+                 label="Repetition penalty",
+                 render=False,
+             ),
+         ],
+         examples=[
+             ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
+             ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
+             ["Tell me a random fun fact about the Roman Empire."],
+             ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
+         ],
+         cache_examples=False,
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch()