LennardZuendorf committed
Commit 7c4f7d6
Parent(s): c4b5a8c

feat: style updates, added license and data protection

app.py CHANGED
@@ -1,7 +1,17 @@
 import gradio as gr
-import chatmodel as chat
+import chatmodel as model
 import interpret as shap
 import visualize as viz
+import markdown
+
+def load_md(filename):
+    path = "./public/"+str(filename)
+
+    # credit: official python-markdown documentation (https://python-markdown.github.io/reference/)
+    with open(path, "r") as file:
+        text = file.read()
+
+    return markdown.markdown(text)
 
 with gr.Blocks() as ui:
     with gr.Row():
@@ -17,47 +27,33 @@ with gr.Blocks() as ui:
             ### ChatBot Demo
            Mitral AI 7B Model fine-tuned for instruction and fully open source (see at [HGF](https://huggingface.co/mistralai/Mistral-7B-v0.1))
            """)
+
+        with gr.Row():
+            chatbot = gr.Chatbot(layout="panel", show_copy_button=True,avatar_images=("./public/human.jpg","./public/bot.jpg"))
+        with gr.Row():
+            gr.Markdown(
+                """
+                ##### ⚠️ All Conversations are recorded for qa assurance and explanation functionality!
+                """)
         with gr.Row():
-            gr.ChatInterface(
-                chat.interference
-            )
+            prompt = gr.Textbox(label="Input Message")
         with gr.Row():
-            gr.Slider(
-                label="Temperature",
-                value=0.7,
-                minimum=0.0,
-                maximum=1.0,
-                step=0.05,
-                interactive=True,
-                info="Higher values produce more diverse outputs",
-            ),
-            gr.Slider(
-                label="Max new tokens",
-                value=256,
-                minimum=0,
-                maximum=1024,
-                step=64,
-                interactive=True,
-                info="The maximum numbers of new tokens",
-            ),
-            gr.Slider(
-                label="Top-p (nucleus sampling)",
-                value=0.95,
-                minimum=0.0,
-                maximum=1,
-                step=0.05,
-                interactive=True,
-                info="Higher values sample more low-probability tokens",
-            ),
-            gr.Slider(
-                label="Repetition penalty",
-                value=1.1,
-                minimum=1.0,
-                maximum=2.0,
-                step=0.05,
-                interactive=True,
-                info="Penalize repeated tokens",
-            )
+            with gr.Column(scale=1):
+                clear_btn = gr.ClearButton([prompt, chatbot])
+            with gr.Column(scale=1):
+                submit_btn = gr.Button("Submit")
+
+        submit_btn.click(model.chat, [prompt, chatbot], [prompt, chatbot])
+        prompt.submit(model.chat, [prompt, chatbot], [prompt, chatbot])
+
+    with gr.Tab("Explanations"):
+        with gr.Row():
+            gr.Markdown(
+                """
+                ### Get Explanations for
+                SHAP Visualization Dashboard adopted from [shapash](https://github.com/MAIF/shapash)
+                """)
+
 
     with gr.Tab("SHAP Dashboard"):
         with gr.Row():
@@ -83,6 +79,9 @@ with gr.Blocks() as ui:
            Adopted from official [model paper](https://arxiv.org/abs/2310.06825) by Mistral AI
            """)
 
+    with gr.Row():
+        with gr.Accordion("Credits, Data Protection and License", open=False):
+            gr.Markdown(value=load_md("credits_dataprotection_license.md"))
 
 if __name__ == "__main__":
     ui.launch(debug=True)
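The new app.py drops gr.ChatInterface in favour of an explicit Chatbot/Textbox/Button layout wired to chatmodel.chat, which returns ("", updated_history) so the textbox is cleared and the chat history refreshed. The sketch below is a minimal, self-contained illustration of that wiring against a recent Gradio release; echo_chat is a hypothetical stand-in for chatmodel.chat and is not part of this commit.

import gradio as gr

# Hypothetical stand-in for chatmodel.chat: it appends a (user, bot) tuple to the
# history and returns an empty string so the input textbox is cleared.
def echo_chat(prompt, history):
    history = history + [(prompt, f"echo: {prompt}")]
    return "", history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    prompt = gr.Textbox(label="Input Message")
    submit_btn = gr.Button("Submit")
    gr.ClearButton([prompt, chatbot])

    # Clicking the button and pressing Enter in the textbox share the same handler,
    # mirroring the submit_btn.click / prompt.submit pair in the diff above.
    submit_btn.click(echo_chat, [prompt, chatbot], [prompt, chatbot])
    prompt.submit(echo_chat, [prompt, chatbot], [prompt, chatbot])

if __name__ == "__main__":
    demo.launch()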
chatmodel.py CHANGED
@@ -4,10 +4,24 @@ import gradio as gr
 
 token = os.environ.get("HGFTOKEN")
 
-client = InferenceClient(
+interference = InferenceClient(
     "mistralai/Mistral-7B-Instruct-v0.1"
 )
 
+model_temperature = 0.7
+model_max_new_tokens = 256
+model_top_p = 0.95
+model_repetition_penalty = 1.1
+
+def chat (prompt, history,):
+
+    formatted_prompt = format_prompt(prompt, history)
+    answer=respond(formatted_prompt)
+
+    history.append((prompt, answer))
+
+    return "",history
+
 def format_prompt(message, history):
     prompt = "<s>"
     for user_prompt, bot_response in history:
@@ -16,33 +30,20 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
-def interference(
-    prompt, history, temperature=0.7, max_new_tokens=256, top_p=0.95, repetition_penalty=1.1,
-):
-    temperature = float(temperature)
+def respond(formatted_prompt):
+    temperature = float(model_temperature)
     if temperature < 1e-2:
         temperature = 1e-2
-    top_p = float(top_p)
+    top_p = float(model_top_p)
 
     generate_kwargs = dict(
         temperature=temperature,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=model_max_new_tokens,
         top_p=top_p,
-        repetition_penalty=repetition_penalty,
+        repetition_penalty=model_repetition_penalty,
         do_sample=True,
         seed=42,
     )
 
-    formatted_prompt = format_prompt(prompt, history)
-
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-        yield output
-    return output
-
-custom=[
-
-]
+    output = interference.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=True, return_full_text=False).generated_text
+    return output
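The refactored chatmodel module separates prompt formatting (format_prompt), the fixed generation settings, and the blocking InferenceClient call (respond), with chat tying them together for the UI. Below is a small usage sketch, assuming the module is importable as chatmodel, HGFTOKEN is set, and the Hugging Face Inference API is reachable; the exact per-turn [INST] formatting inside format_prompt lies outside the hunks shown above.

import chatmodel as model

# Prompt formatting is deterministic and runs offline: prior turns plus the new
# message are wrapped in Mistral-style [INST] ... [/INST] instruction tags.
history = [("Hi", "Hello! How can I help?")]
print(model.format_prompt("What is SHAP?", history))

# chat() returns ("", updated_history): the empty string clears the UI textbox and
# the history feeds the gr.Chatbot component. This call goes to the remote model.
_, history = model.chat("What is SHAP?", history)
print(history[-1][1])  # latest bot answer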
public/bot.jpg ADDED
public/credits_dataprotection_license.md ADDED
@@ -0,0 +1,6 @@
+### Credits
+
+### Data Protection
+
+### License
+This Product is licensed under the MIT license.
public/human.jpg ADDED
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gradio
 transformers
 torch
 shap
-accelerate
+accelerate
+markdown