aixsatoshi committed
Commit 9a6b8ed
1 Parent(s): 23d16e2

Update app.py

Files changed (1)
  1. app.py +15 -20
app.py CHANGED
@@ -5,7 +5,8 @@ import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import os
 from threading import Thread
-
+import random
+from datasets import load_dataset
 
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL_ID = "TeamDelta/mistral-yuki-7B"
@@ -42,14 +43,19 @@ h3 {
 }
 """
 
-
+# Load the model and tokenizer
 model = AutoModelForCausalLM.from_pretrained(
-MODEL_ID,
-torch_dtype=torch.float16,
-device_map="auto",
-)
+    MODEL_ID,
+    torch_dtype=torch.float16,
+    device_map="auto",
+)
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
+# Load the dataset and pick 10 examples
+dataset = load_dataset("elyza/ELYZA-tasks-100")
+examples = random.sample(dataset['train'], 10)
+example_inputs = [example['input'] for example in examples]
+
 @spaces.GPU
 def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
     print(f'message is - {message}')
@@ -59,8 +65,6 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
         conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
     conversation.append({"role": "user", "content": message})
 
-    #print(f"Conversation is -\n{conversation}")
-
     input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer(input_ids, return_tensors="pt").to(0)
 
@@ -75,7 +79,7 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
-        eos_token_id = [128001, 128009],
+        eos_token_id=[128001, 128009],
     )
 
     thread = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -86,8 +90,6 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
         buffer += new_text
         yield buffer
 
-
-
 chatbot = gr.Chatbot(height=500)
 
 with gr.Blocks(css=CSS) as demo:
@@ -144,16 +146,9 @@ with gr.Blocks(css=CSS) as demo:
            render=False,
        ),
    ],
-    examples=[
-        ["超能力を持つ主人公のSF物語のシナリオを考えてください。伏線の設定、テーマやログラインを理論的に使用してください"],
-        ["子供の夏休みの自由研究のための、5つのアイデアと、その手法を簡潔に教えてください。"],
-        ["パズルゲームのスクリプト作成のためにアドバイスお願いします"],
-        ["マークダウン記法にて、ブロック崩しのゲーム作成の教科書作成してください"],
-    ],
+    examples=example_inputs,
    cache_examples=False,
 )
 
-
-
 if __name__ == "__main__":
-demo.launch()
+    demo.launch()
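The example-loading block in the second hunk indexes dataset['train'] and passes a datasets.Dataset straight to random.sample. ELYZA-tasks-100 is distributed with its 100 tasks under a test split rather than train, and on Python 3.11+ random.sample() requires a true sequence, so both lines can fail at startup. A minimal defensive sketch, reusing the app's variable names (the split-detection fallback is an assumption, not part of this commit):

import random
from datasets import load_dataset

# Load the benchmark; fall back to the first available split if "train" is
# absent (ELYZA-tasks-100 publishes its rows under "test").
dataset = load_dataset("elyza/ELYZA-tasks-100")
split = "train" if "train" in dataset else next(iter(dataset))
tasks = dataset[split]

# Sample row indices rather than the Dataset object itself, since
# random.sample() on Python 3.11+ only accepts real sequences.
indices = random.sample(range(len(tasks)), k=10)
example_inputs = [tasks[i]["input"] for i in indices]

Drawing indices also keeps the sample cheap: only the ten selected rows are ever materialized as Python dicts.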
 
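The fourth hunk keeps eos_token_id=[128001, 128009] in generate_kwargs. Those IDs match Llama-3 special tokens (<|end_of_text|> and <|eot_id|>), while MODEL_ID names a Mistral-style 7B whose vocabulary is roughly 32k tokens, so the hard-coded values likely fall outside the model's vocab. A sketch of the same streamer-plus-background-thread pattern that derives the stop token from the loaded tokenizer instead (the helper name and defaults are illustrative, not from the commit):

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, conversation, max_new_tokens=512, temperature=0.8):
    """Yield a growing response string while model.generate() runs on a thread."""
    prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    # Move tensors to wherever device_map="auto" placed the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        # Ask the tokenizer for its own end-of-sequence id instead of hard-coding one.
        eos_token_id=tokenizer.eos_token_id,
    )

    # generate() blocks, so run it on a worker thread and drain the streamer here.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

Because Gradio consumes generators natively, stream_chat can delegate to a helper like this and keep its @spaces.GPU decorator unchanged.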