Shriharsh committed on
Commit 6faaf41 • 1 Parent(s): 60a9a03

Update app.py

Files changed (1)
  1. app.py +48 -90
app.py CHANGED
@@ -1,97 +1,55 @@
- from huggingface_hub import InferenceClient
  import gradio as gr

- client = InferenceClient(
-     "meta-llama/Meta-Llama-3-8B-Instruct"
- )

- def format_prompt(message, history):
-     prompt = "<s>"
-     for user_prompt, bot_response in history:
-         prompt += f"[INST] {user_prompt} [/INST]"
-         prompt += f" {bot_response}</s> "
-     prompt += f"[INST] {message} [/INST]"
-     return prompt

- def generate(
-     prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
- ):
-     temperature = float(temperature)
-     if temperature < 1e-2:
-         temperature = 1e-2
-     top_p = float(top_p)

-     generate_kwargs = dict(
-         temperature=temperature,
          max_new_tokens=max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=repetition_penalty,
          do_sample=True,
-         seed=42,
-     )
-
-     formatted_prompt = format_prompt(prompt, history)
-
-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     output = ""
-
-     for response in stream:
-         output += response.token.text
-         yield output
-     return output
-
-
- additional_inputs=[
-     gr.Slider(
-         label="Temperature",
-         value=0.9,
-         minimum=0.0,
-         maximum=1.0,
-         step=0.05,
-         interactive=True,
-         info="Higher values produce more diverse outputs",
-     ),
-     gr.Slider(
-         label="Max new tokens",
-         value=512,
-         minimum=0,
-         maximum=1048,
-         step=64,
-         interactive=True,
-         info="The maximum numbers of new tokens",
-     ),
-     gr.Slider(
-         label="Top-p (nucleus sampling)",
-         value=0.90,
-         minimum=0.0,
-         maximum=1,
-         step=0.05,
-         interactive=True,
-         info="Higher values sample more low-probability tokens",
-     ),
-     gr.Slider(
-         label="Repetition penalty",
-         value=1.2,
-         minimum=1.0,
-         maximum=2.0,
-         step=0.05,
-         interactive=True,
-         info="Penalize repeated tokens",
-     )
- ]
-
- # Create a Chatbot object with the desired height
- chatbot = gr.Chatbot(height=450,
-                      layout="bubble")
-
- with gr.Blocks() as demo:
-     gr.HTML("<h1><center>🤖 Mistral-7B-Chat 💬<h1><center>")
-     gr.ChatInterface(
-         generate,
-         chatbot=chatbot, # Use the created Chatbot object
-         additional_inputs=additional_inputs,
-         examples=[["Give me the code for Binary Search in C++"], ["Explain the chapter of The Grand Inquistor from The Brothers Karmazov"],["Why a penguin would make a great president? Give supporting arguments."]],
-
-     )
-
- demo.queue().launch(debug=True)
  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch

+ model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+ device_map = 'auto'

+ def load_model() -> AutoModelForCausalLM:
+     # Loading an 8B model in full precision can exhaust memory; bfloat16 keeps it manageable.
+     return AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map=device_map)
+
+ def load_tokenizer() -> AutoTokenizer:
+     return AutoTokenizer.from_pretrained(model_name)

+ def preprocess_messages(message: str, history: list, system_prompt: str) -> str:
+     # Rebuild the full conversation (system prompt, prior turns, new message) as chat messages.
+     messages = [{'role': 'system', 'content': system_prompt}]
+     for user_turn, assistant_turn in history:
+         messages.append({'role': 'user', 'content': user_turn})
+         messages.append({'role': 'assistant', 'content': assistant_turn})
+     messages.append({'role': 'user', 'content': message})
+     prompt = load_tokenizer().apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     return prompt

+ def generate_text(prompt: str, max_new_tokens: int, temperature: float) -> str:
+     model = load_model()
+     tokenizer = load_tokenizer()
+     # Stop on either the end-of-sequence token or Llama 3's end-of-turn token.
+     terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+     # Keep temperature strictly positive so sampling does not fail at 0.
+     temp = max(float(temperature), 0.1)
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(
+         **inputs,
          max_new_tokens=max_new_tokens,
+         eos_token_id=terminators,
          do_sample=True,
+         temperature=temp,
+         top_p=0.9
+     )
+     # Decode only the newly generated tokens, not the echoed prompt.
+     return tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+
+ def chat_function(
+     message: str,
+     history: list,
+     system_prompt: str,
+     max_new_tokens: int,
+     temperature: float
+ ) -> str:
+     prompt = preprocess_messages(message, history, system_prompt)
+     return generate_text(prompt, max_new_tokens, temperature)
+
+ gr.ChatInterface(
+     chat_function,
+     chatbot=gr.Chatbot(height=400),
+     textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
+     title="LLAMA3 Chat",
+     description="""Chat with llama3""",
+     theme="soft",
+     additional_inputs=[
+         gr.Textbox("You shall answer to all the questions as very smart AI", label="System Prompt"),
+         gr.Slider(512, 4096, label="Max New Tokens"),
+         gr.Slider(0, 1, label="Temperature")
+     ]
+ ).launch()
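
Note on the new version: it loads the model and tokenizer from scratch on every chat request and returns the whole reply at once, while the removed InferenceClient code streamed tokens back. A minimal sketch of how both points could be handled locally is below; it assumes transformers' TextIteratorStreamer, and the helper names (load_model_cached, load_tokenizer_cached, stream_reply) are illustrative, not part of this commit.

import threading
from functools import lru_cache

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

@lru_cache(maxsize=1)
def load_model_cached():
    # Load once per process instead of once per chat request.
    return AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )

@lru_cache(maxsize=1)
def load_tokenizer_cached():
    return AutoTokenizer.from_pretrained(model_name)

def stream_reply(prompt: str, max_new_tokens: int, temperature: float):
    model = load_model_cached()
    tokenizer = load_tokenizer_cached()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt=True so only newly generated text is yielded, not the echoed prompt.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=max(float(temperature), 0.1),
        top_p=0.9,
    )
    # generate() blocks, so run it in a background thread and yield partial text
    # as it arrives, which is the shape gr.ChatInterface expects from a generator callback.
    threading.Thread(target=model.generate, kwargs=generate_kwargs).start()
    output = ""
    for chunk in streamer:
        output += chunk
        yield output

With something like this, chat_function could yield from stream_reply(...) instead of returning a single string, and gr.ChatInterface would display the reply incrementally as the old InferenceClient version did.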