lamhieu committed
Commit ef88752 · 1 Parent(s): 7a736e5

chore: support tools with search on internet

Files changed (3)
  1. README.md +2 -4
  2. app.py +247 -56
  3. requirements.txt +3 -1
README.md CHANGED
@@ -31,11 +31,9 @@ tags:
 
 ### Notes
 
-The extension source code belongs to: "LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning".
+The extension source code belongs to: "LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning". See source code details [here](https://github.com/datamllab/LongLM).
 
-See source code details [here](https://github.com/datamllab/LongLM).
-
-```
+```tex
 @misc{jin2024llm,
   title={LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning},
   author={Hongye Jin and Xiaotian Han and Jingfeng Yang and Zhimeng Jiang and Zirui Liu and Chia-Yuan Chang and Huiyuan Chen and Xia Hu},

app.py CHANGED
@@ -1,6 +1,8 @@
 # pylint: skip-file
 
 import subprocess
+import json
+import requests
 
 subprocess.run(
     f"pip install flash-attn --no-build-isolation",
@@ -15,24 +17,27 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
+import wikipedia
+import time
 import SelfExtend
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from bs4 import BeautifulSoup
+from functools import lru_cache
 
 
-MAX_MAX_NEW_TOKENS = 4096
-DEFAULT_MAX_NEW_TOKENS = 1536
+MAX_MAX_NEW_TOKENS = 8192
+DEFAULT_MAX_NEW_TOKENS = 2048
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "123392"))
 
 DESCRIPTION = """\
-# Playground with Ghost 8B Beta (β, 128k)
+# Playground with Ghost 8B Beta (β, 8k)
 
-**Ghost 8B Beta** is a large language model developed with goals that include excellent multilingual support, superior knowledge capabilities, and cost-effectiveness. The model comes in two context length versions, [8k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-8k) and [128k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-128k), along with multilingual function tools support by default.
-
-The Ghost 8B Beta model outperforms prominent models such as Llama 3 8B Instruct, GPT 3.5 Turbo in the lc_winrate score. In addition, it also outperforms Claude 3 Opus, Claude 3 Sonnet, GPT-4, and Mistral Large when comparing the winrate score of AlpacaEval 2.0, [*](https://ghost-x.org/docs/models/ghost-8b-beta/).
+**Ghost 8B Beta** model outperforms prominent models such as Llama 3 8B Instruct, GPT 3.5 Turbo in the lc_winrate score. In addition, it also outperforms Claude 3 Opus, Claude 3 Sonnet, GPT-4, and Mistral Large when comparing the winrate score of AlpacaEval 2.0, [*](https://ghost-x.org/docs/models/ghost-8b-beta/). The model comes in two context length versions, [8k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-8k) and [128k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-128k), along with multilingual function tools support by default.
 
 The languages supported are 🇺🇸 English, 🇫🇷 French, 🇮🇹 Italian, 🇪🇸 Spanish, 🇵🇹 Portuguese, 🇩🇪 German, 🇻🇳 Vietnamese, 🇰🇷 Korean and 🇨🇳 Chinese.
 
-📋 Note: current model version is "disl-0x5" (10 Jul 2024), context length 128k (123392 tokens) and current status is "moderating / previewing". For detailed information about the model, see [here](https://ghost-x.org/docs/models/ghost-8b-beta/). Try to experience it the way you want!
+🗞️ **Updates**
+* Jul 23, 2024: added support for tools, now available to search for information on the internet.
 """
 
 
@@ -251,19 +256,19 @@ if not torch.cuda.is_available():
 
 if torch.cuda.is_available():
     model_id = "ghost-x/ghost-8b-beta"
-    model_tk = os.getenv("HF_TOKEN", None)
+    hf_serect = os.getenv("HF_TOKEN", None)
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         device_map="auto",
         torch_dtype=torch.bfloat16,
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
-        token=model_tk,
+        token=hf_serect,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_id,
         trust_remote_code=True,
-        token=model_tk,
+        token=hf_serect,
     )
     SelfExtend.apply(
         model,
@@ -274,73 +279,259 @@ if torch.cuda.is_available():
     )
     model.generation_config.max_length = 123392
+
+waiting_tools_timeout = 7.5
+supported_tools = json.dumps(
+    [
+        {
+            "type": "function",
+            "function": {
+                "name": "search_on_internet",
+                "description": "Use this tool to search online, only use it for information you don't know or are unsure of, don't abuse it.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "keyword": {
+                            "type": "string",
+                            "description": "Search keywords, rephrase to optimize search results based on questions suitable to the specified search type.",
+                            "required": True,
+                        },
+                        "type": {
+                            "type": "string",
+                            "description": "Search type, based on the question to determine whether to search for it in 'wikipedia' or 'google', prefer to use wikipedia for information about events, history and people.",
+                            "enum": ["wikipedia", "google"],
+                            "default": "google",
+                            "required": True,
+                        },
+                    },
+                },
+            },
+        }
+    ],
+    ensure_ascii=False,
+)
+
+
+@lru_cache(maxsize=128)
+def extract_text_from_webpage(html_content):
+    soup = BeautifulSoup(html_content, "html.parser")
+    for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
+        tag.extract()
+    visible_text = soup.get_text(strip=True, separator=" ")
+    return visible_text
+
+
+def search_with_wikipedia(query: str):
+    all_results = []
+    try:
+        all_results.append(wikipedia.summary(query))
+    except Exception as e:
+        pass
+    return all_results
+
+
+def search_with_google(
+    query: str,
+    num_results: int = 3,
+    timeout: int = 5,
+    ssl_verify: bool = None,
+):
+    all_results = []
+    max_chars_per_page = 4096
+    with requests.Session() as session:
+        resp = session.get(
+            url="https://www.google.com/search",
+            headers={
+                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
+            },
+            params={
+                "q": query,
+                "num": num_results,
+                "udm": 14,
+            },
+            timeout=timeout,
+            verify=ssl_verify,
+        )
+        resp.raise_for_status()
+        soup = BeautifulSoup(resp.text, "html.parser")
+        result_block = soup.find_all("div", attrs={"class": "g"})
+        for result in result_block:
+            link = result.find("a", href=True)
+            if link:
+                link = link["href"]
+                try:
+                    webpage = session.get(
+                        link,
+                        headers={
+                            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
+                        },
+                    )
+                    webpage.raise_for_status()
+                    visible_text = extract_text_from_webpage(webpage.text)
+                    if len(visible_text) > max_chars_per_page:
+                        visible_text = visible_text[:max_chars_per_page]
+                    all_results.append({"link": link, "text": visible_text})
+                except requests.exceptions.RequestException as e:
+                    print(f"Error fetching or processing {link}: {e}")
+                    pass
+            else:
+                pass
+    return all_results
 
 
-@spaces.GPU(duration=120)
+@spaces.GPU(duration=180)
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
-    system_prompt: str,
-    max_new_tokens: int = 1536,
+    allow_used_tools: bool = True,
+    system_prompt: str = "",
+    max_new_tokens: int = 2048,
     temperature: float = 0.4,
     top_p: float = 0.95,
     top_k: int = 50,
     repetition_penalty: float = 1.0,
 ) -> Iterator[str]:
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    for user, assistant in chat_history:
-        conversation.extend(
-            [
-                {"role": "user", "content": user},
-                {"role": "assistant", "content": assistant},
-            ]
-        )
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(
-        conversation, add_generation_prompt=True, return_tensors="pt"
-    )
-    input_ids = input_ids.to(model.device)
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(
-            f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
-        )
-
-    streamer = TextIteratorStreamer(
-        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
-    )
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        repetition_penalty=repetition_penalty,
-    )
-    if temperature == 0:
-        generate_kwargs["do_sample"] = False
-    else:
-        generate_kwargs["temperature"] = temperature
-        generate_kwargs["top_p"] = top_p
-        generate_kwargs["top_k"] = top_k
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    # print()
+    # print("allow_used_tools:\n", allow_used_tools)
+    # print("system_prompt:\n", system_prompt)
+    # print("max_new_tokens:\n", max_new_tokens)
+    # print("temperature:\n", temperature)
+
+    def build_input_ids(
+        apply_tools: bool = None,
+        references: list[str] = None,
+    ):
+        conversation = []
+        if system_prompt:
+            conversation.append({"role": "system", "content": system_prompt})
+        if apply_tools is True:
+            conversation.append({"role": "tools", "content": supported_tools})
+        if (
+            references is not None
+            and isinstance(references, list)
+            and len(references) > 0
+        ):
+            conversation.append(
+                {
+                    "role": "refs",
+                    "content": json.dumps(references, ensure_ascii=False),
+                }
+            )
+
+        for user, assistant in chat_history:
+            conversation.extend(
+                [
+                    {"role": "user", "content": user},
+                    {"role": "assistant", "content": assistant},
+                ]
+            )
+        conversation.append({"role": "user", "content": message})
+
+        input_ids = tokenizer.apply_chat_template(
+            conversation, add_generation_prompt=True, return_tensors="pt"
+        )
+        input_ids = input_ids.to(model.device)
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(
+                f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
+            )
+        return input_ids
+
+    def generate_chat_responses(
+        previous_response: str = None,
+    ):
+        document_references = []
+        if previous_response is not None:
+            scheduled_tools_runs = None
+            try:
+                scheduled_tools_runs = json.loads(previous_response)
+                if scheduled_tools_runs["type"] == "function" and scheduled_tools_runs[
+                    "name"
+                ] in ["search_on_internet"]:
+                    pass
+                else:
+                    scheduled_tools_runs = None
+            except Exception as e:
+                print(e)
+                pass
+
+            if (
+                scheduled_tools_runs is not None
+                and scheduled_tools_runs["name"] == "search_on_internet"
+            ):
+                keyword = scheduled_tools_runs["arguments"]["keyword"]
+                search_type = scheduled_tools_runs["arguments"]["type"]
+                if search_type == "wikipedia":
+                    gr.Info("Searching for information on the Wikipedia.")
+                    document_references = search_with_wikipedia(keyword)
+                else:
+                    gr.Info("Searching for information on the Google.")
+                    document_references = search_with_google(keyword)
+
+        input_ids = build_input_ids(
+            apply_tools=(
+                True
+                if allow_used_tools is True and previous_response is None
+                else False
+            ),
+            references=document_references,
+        )
+        streamer = TextIteratorStreamer(
+            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generate_kwargs = dict(
+            input_ids=input_ids,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            repetition_penalty=repetition_penalty,
+        )
+        if temperature == 0:
+            generate_kwargs["do_sample"] = False
+        else:
+            generate_kwargs["temperature"] = temperature
+            generate_kwargs["top_p"] = top_p
+            generate_kwargs["top_k"] = top_k
+
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        state = {
+            "mark": None,
+            "respond": False,
+        }
+        outputs = []
+        for text in streamer:
+            if state["mark"] is None:
+                state["mark"] = time.time()
+            outputs.append(text)
+            if state["mark"] + waiting_tools_timeout < time.time():
+                state["respond"] = True
+                yield "".join(outputs)
+
+        if (
+            state["respond"] is False
+            and state["mark"] + waiting_tools_timeout > time.time()
+        ):
+            gr.Info("Searching for information on the internet.")
+            previous_response = "".join(outputs)
+            yield from generate_chat_responses(previous_response=previous_response)
+
+    yield from generate_chat_responses(previous_response=None)
 
 
-chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta")
+chatbot = gr.Chatbot(
+    height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta", show_copy_button=True
+)
 
 chat_interface = gr.ChatInterface(
     fn=generate,
     chatbot=chatbot,
     fill_height=True,
     additional_inputs=[
+        gr.Checkbox(
+            label="Allow used tools (available: search on internet)", value=True
+        ),
         gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
@@ -382,6 +573,7 @@ chat_interface = gr.ChatInterface(
     cache_examples=False,
     examples=EXAMPLES,
     examples_per_page=9,
+    concurrency_limit=100,
 )
 
 with gr.Blocks(fill_height=True, css="style.css") as demo:
@@ -391,4 +583,3 @@ with gr.Blocks(fill_height=True, css="style.css") as demo:
 
 if __name__ == "__main__":
     demo.queue(max_size=20).launch(share=True)
-    # demo.launch(share=True)

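The new two-pass flow in app.py lets the model first emit a JSON tool call, which `generate_chat_responses` parses before running the search and regenerating with a `refs` message attached. Below is a minimal sketch of a payload that would pass that parsing: the key names ("type", "name", "arguments", "keyword") come from the diff above, while the example keyword value is only an illustrative assumption.

```python
import json

# Hypothetical tool call the model might stream as its first response when
# tools are enabled; the keys mirror what generate_chat_responses reads.
tool_call = {
    "type": "function",
    "name": "search_on_internet",
    "arguments": {"keyword": "Ghost 8B Beta language model", "type": "google"},
}

# The app round-trips the streamed text through json.loads, so the raw output
# must be valid JSON such as this string:
previous_response = json.dumps(tool_call, ensure_ascii=False)

parsed = json.loads(previous_response)
assert parsed["type"] == "function"
assert parsed["name"] in ["search_on_internet"]
print(parsed["arguments"]["keyword"], parsed["arguments"]["type"])
```
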
requirements.txt CHANGED
@@ -1,8 +1,10 @@
 accelerate==0.30.1
 bitsandbytes==0.43.1
-gradio==4.37.2
+gradio==4.39.0
 scipy==1.13.0
 sentencepiece==0.2.0
 spaces==0.28.3
 torch==2.0.0
 transformers==4.41.0
+beautifulsoup4>=4.9
+wikipedia==1.4.0
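The two new pins back the search helpers added in app.py. A quick, standalone sanity check of those dependencies (not part of the Space itself; the query and HTML fragment below are arbitrary examples):

```python
# Standalone check for the newly pinned dependencies.
import wikipedia
from bs4 import BeautifulSoup

# wikipedia==1.4.0: fetch a one-sentence summary, the same call that
# search_with_wikipedia wraps.
print(wikipedia.summary("Python (programming language)", sentences=1))

# beautifulsoup4: strip non-content tags from a small HTML fragment,
# mirroring extract_text_from_webpage in the diff above.
html = "<html><body><nav>menu</nav><p>Hello, world.</p></body></html>"
soup = BeautifulSoup(html, "html.parser")
for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
    tag.extract()
print(soup.get_text(strip=True, separator=" "))
```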