leonardlin committed
Commit 554f3ed
1 Parent(s): 36badd8

Update app.py

Files changed (1)
  1. app.py +107 -2
app.py CHANGED
@@ -4,7 +4,7 @@ from huggingface_hub import InferenceClient
 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+client = InferenceClient("shisa-ai/shisa-llama3-8b-v1")


 def respond(
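(The diff truncates `respond` at its signature. For reference, the stock `gr.ChatInterface` template this file follows streams completions from the hosted model via `InferenceClient.chat_completion`; the sketch below shows the usual shape of that body, with parameter names assumed from the template rather than taken from this commit.)

def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Rebuild an OpenAI-style message list from Gradio's (user, assistant) history
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Stream tokens from the hosted model, yielding the growing reply each chunk
    response = ""
    for chunk in client.chat_completion(
        messages, max_tokens=max_tokens, stream=True,
        temperature=temperature, top_p=top_p,
    ):
        response += chunk.choices[0].delta.content or ""
        yield response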
@@ -60,4 +60,109 @@ demo = gr.ChatInterface(


 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
+
+'''
+# https://www.gradio.app/guides/using-hugging-face-integrations
+
+import gradio as gr
+import logging
+import html
+from pprint import pprint
+import time
+import torch
+from threading import Thread
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
+
+# Model
+model_name = "augmxnt/shisa-7b-v1"
+
+# UI Settings
+title = "Shisa 7B"
+description = "Test out <a href='https://huggingface.co/augmxnt/shisa-7b-v1'>Shisa 7B</a> in either English or Japanese. If you aren't getting the right language outputs, you can try changing the system prompt to the appropriate language.\n\nNote: we are running this model quantized at `load_in_4bit` to fit in 16GB of VRAM."
+placeholder = "Type Here / ここに入力してください"
+examples = [
+    ["What are the best slices of pizza in New York City?"],
+    ["東京でおすすめのラーメン屋ってどこ?"],
+    ['How do I program a simple "hello world" in Python?'],
+    ["Pythonでシンプルな「ハローワールド」をプログラムするにはどうすればいいですか?"],
+]
+
+# LLM Settings
+# Initial system prompt
+system_prompt = 'You are a helpful, bilingual assistant. Reply in the same language as the user.'
+default_prompt = system_prompt
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    # load_in_8bit=True,
+    load_in_4bit=True,
+    use_flash_attention_2=True,
+)
+
+def chat(message, history, system_prompt):
+    if not system_prompt:
+        system_prompt = default_prompt
+
+    print('---')
+    print('Prompt:', system_prompt)
+    pprint(history)
+    print(message)
+
+    # Just rebuild the whole chat history every turn; it's easier
+    chat_history = [{"role": "system", "content": system_prompt}]
+    for h in history:
+        chat_history.append({"role": "user", "content": h[0]})
+        chat_history.append({"role": "assistant", "content": h[1]})
+    chat_history.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt")
+
+    # For multi-GPU, find the device of the first parameter of the model
+    first_param_device = next(model.parameters()).device
+    input_ids = input_ids.to(first_param_device)
+
+    generate_kwargs = dict(
+        inputs=input_ids,
+        max_new_tokens=200,
+        do_sample=True,
+        temperature=0.7,
+        repetition_penalty=1.15,
+        top_p=0.95,
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+
+    output_ids = model.generate(**generate_kwargs)
+    new_tokens = output_ids[0, input_ids.size(1):]
+    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
+    return response
+
+
+chat_interface = gr.ChatInterface(
+    chat,
+    chatbot=gr.Chatbot(height=400),
+    textbox=gr.Textbox(placeholder=placeholder, container=False, scale=7),
+    title=title,
+    description=description,
+    theme="soft",
+    examples=examples,
+    cache_examples=False,
+    undo_btn="Delete Previous",
+    clear_btn="Clear",
+    additional_inputs=[
+        gr.Textbox(system_prompt, label="System Prompt (Change the language of the prompt for better replies)"),
+    ],
+)
+
+# https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219 - we use this `with` construction b/c Gradio barfs on autoreload otherwise
+with gr.Blocks() as demo:
+    chat_interface.render()
+    gr.Markdown("You can try asking this question in Japanese or English. We limit output to 200 tokens.")
+
+demo.queue().launch()
+'''
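
A review note on the commented-out block above: `BitsAndBytesConfig` is imported but never used, and `load_in_4bit=True` is passed straight to `from_pretrained`. That shortcut worked on transformers releases of that era but has since been deprecated in favor of an explicit quantization config. A minimal sketch of the equivalent, assuming `bnb_4bit_compute_dtype` should match the `torch_dtype` used above:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: match the bfloat16 torch_dtype above
)
model = AutoModelForCausalLM.from_pretrained(
    "augmxnt/shisa-7b-v1",
    device_map="auto",
    quantization_config=quant_config,
)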
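
Similarly, `Thread` and `TextIteratorStreamer` are imported, but `chat()` blocks on `generate()` and returns only the final decoded string. If token streaming was the intent, here is a minimal sketch of how those two imports would plug in (the `stream_reply` helper is hypothetical, not part of this commit):

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, generate_kwargs):
    # hypothetical helper: generate() blocks, so run it on a worker thread
    # and drain the streamer on this one
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(target=model.generate, kwargs=dict(generate_kwargs, streamer=streamer))
    thread.start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # gr.ChatInterface renders each partial reply as it arrives
    thread.join()

Inside `chat()`, the final `generate`/`decode`/`return` lines would then become `yield from stream_reply(model, tokenizer, generate_kwargs)`.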