LennardZuendorf committed on
Commit c4b5a8c
1 Parent(s): d19acd3

feat: switched model to Mistral AI 7B

Files changed (2)
  1. app.py +45 -6
  2. chatmodel.py +37 -50
app.py CHANGED
@@ -10,15 +10,54 @@ with gr.Blocks() as ui:
         # Thesis Demo - AI Chat Application with XAI
         ### Select between tabs below for the different views.
         """)
-    with gr.Tab("LlaMa 2 ChatBot"):
+    with gr.Tab("Mistral AI ChatBot"):
         with gr.Row():
             gr.Markdown(
                 """
                 ### ChatBot Demo
-                LlaMa 2 7B Model fine-tuned for chat and transformed to huggingface format (see at [HGF](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf))
+                Mistral AI 7B Model fine-tuned for instruction and fully open source (see at [HGF](https://huggingface.co/mistralai/Mistral-7B-v0.1))
                 """)
         with gr.Row():
-            gr.ChatInterface(chat.interference)
+            gr.ChatInterface(
+                chat.interference
+            )
+        with gr.Row():
+            gr.Slider(
+                label="Temperature",
+                value=0.7,
+                minimum=0.0,
+                maximum=1.0,
+                step=0.05,
+                interactive=True,
+                info="Higher values produce more diverse outputs",
+            ),
+            gr.Slider(
+                label="Max new tokens",
+                value=256,
+                minimum=0,
+                maximum=1024,
+                step=64,
+                interactive=True,
+                info="The maximum number of new tokens",
+            ),
+            gr.Slider(
+                label="Top-p (nucleus sampling)",
+                value=0.95,
+                minimum=0.0,
+                maximum=1,
+                step=0.05,
+                interactive=True,
+                info="Higher values sample more low-probability tokens",
+            ),
+            gr.Slider(
+                label="Repetition penalty",
+                value=1.1,
+                minimum=1.0,
+                maximum=2.0,
+                step=0.05,
+                interactive=True,
+                info="Penalize repeated tokens",
+            )

     with gr.Tab("SHAP Dashboard"):
         with gr.Row():
@@ -36,12 +75,12 @@ with gr.Blocks() as ui:
                 Visualization Dashboard adopted from [BERTViz](https://github.com/jessevig/bertviz)
                 """)

-    with gr.Tab("LlaMa 2 Model Overview"):
+    with gr.Tab("Mistral Model Overview"):
         with gr.Row():
             gr.Markdown(
                 """
-                ### LlaMa 2 Model & Data Overview for Transparency
-                Adopted from official [model paper](https://arxiv.org/abs/2307.09288) by Meta AI
+                ### Mistral 7B Model & Data Overview for Transparency
+                Adopted from official [model paper](https://arxiv.org/abs/2310.06825) by Mistral AI
                 """)

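Note on the hunk above: the four new sliders are created inside their own gr.Row() but are not passed to the ChatInterface in this commit, so their values would not yet reach chat.interference. A minimal sketch of how they could be wired up with Gradio's additional_inputs argument follows; the module alias `chat` and the launch lines are assumptions, not part of this commit.

```python
# Illustrative sketch only (not part of this commit): forwarding the slider values
# to chat.interference via gr.ChatInterface's additional_inputs mechanism.
import gradio as gr

import chatmodel as chat  # assumes app.py imports the chat module under this alias

# Same sliders as in the hunk above, kept unrendered until ChatInterface places them.
temperature = gr.Slider(label="Temperature", value=0.7, minimum=0.0, maximum=1.0, step=0.05)
max_new_tokens = gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1024, step=64)
top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.95, minimum=0.0, maximum=1.0, step=0.05)
repetition_penalty = gr.Slider(label="Repetition penalty", value=1.1, minimum=1.0, maximum=2.0, step=0.05)

with gr.Blocks() as ui:
    with gr.Tab("Mistral AI ChatBot"):
        # additional_inputs are appended, in order, after (message, history) when Gradio
        # calls the chat function, matching interference's keyword defaults.
        gr.ChatInterface(
            chat.interference,
            additional_inputs=[temperature, max_new_tokens, top_p, repetition_penalty],
        )

if __name__ == "__main__":
    ui.queue().launch()  # queue() enables streaming from the generator chat function
```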
chatmodel.py CHANGED
@@ -1,61 +1,48 @@
-from transformers import pipeline
-import torch
-from transformers import AutoTokenizer
+from huggingface_hub import InferenceClient
 import os
+import gradio as gr

 token = os.environ.get("HGFTOKEN")

-model = "meta-llama/Llama-2-7b-chat-hf"
-tokenizer = AutoTokenizer.from_pretrained(model, token=token)
-
-llama_pipeline = pipeline(
-    "text-generation",
-    model=model,
-    torch_dtype=torch.float32,
-    device_map="auto",
-    token = token
+client = InferenceClient(
+    "mistralai/Mistral-7B-Instruct-v0.1"
 )

-# Formatting function for message and history
-def format_message(message: str, history: list, system_prompt:str, memory_limit: int = 3) -> str:
-
-    if len(history) > memory_limit:
-        history = history[-memory_limit:]
-
-    system_prompt="<s>[INST] <<SYS>>\n"+system_prompt+"\n<</SYS>>"
-
-    if len(history) == 0:
-        return system_prompt + f"{message} [/INST]"
-
-    formatted_message = system_prompt + f"{history[0][0]} [/INST] {history[0][1]} </s>"
-
-    # Handle conversation history
-    for user_msg, model_answer in history[1:]:
-        formatted_message += f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"
-
-    # Handle the current message
-    formatted_message += f"<s>[INST] {message} [/INST]"
-
-    return formatted_message
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt

-# Generate a response from the Llama model
-def interference(message: str, history: list, ) -> str:
-    system_prompt="You are a helpful assistant providing reasonable answers."
+def interference(
+    prompt, history, temperature=0.7, max_new_tokens=256, top_p=0.95, repetition_penalty=1.1,
+):
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)

-    query = format_message(message, history, system_prompt)
-    response = ""
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=42,
+    )

-    sequences = llama_pipeline(
-        query,
-        do_sample=True,
-        top_k=10,
-        num_return_sequences=1,
-        eos_token_id=tokenizer.eos_token_id,
-        max_length=1024,
-    )
+    formatted_prompt = format_prompt(prompt, history)

-    generated_text = sequences[0]['generated_text']
-    response = generated_text[len(query):]  # Remove the prompt from the output
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""

-    print("Chatbot:", response.strip())
-    return response.strip()
+    for response in stream:
+        output += response.token.text
+        yield output
+    return output

+custom=[
+
+]
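For reference, a small usage sketch of the new chatmodel.py API after this change. It assumes a valid HGFTOKEN is configured and the hosted Mistral-7B-Instruct endpoint is reachable; the example conversation below is invented.

```python
# Hypothetical usage of the new chatmodel.py functions; the history content is made up.
from chatmodel import format_prompt, interference

history = [("What does XAI mean?", "XAI stands for explainable AI.")]

# format_prompt wraps the turns in Mistral's instruction format, e.g.:
# "<s>[INST] What does XAI mean? [/INST] XAI stands for explainable AI.</s> [INST] And SHAP? [/INST]"
print(format_prompt("And SHAP?", history))

# interference is a generator: each yielded value is the accumulated answer so far,
# which is the shape gr.ChatInterface expects from a streaming chat function.
final = ""
for partial in interference("And SHAP?", history, temperature=0.7, max_new_tokens=64):
    final = partial
print(final)
```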