vilarin commited on
Commit
22f5f54
1 Parent(s): 5300ae4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -35
app.py CHANGED
@@ -2,18 +2,18 @@ import torch
2
  from PIL import Image
3
  import gradio as gr
4
  import spaces
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
6
  import os
7
  from threading import Thread
8
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
- MODEL_ID = "CohereForAI/aya-23-8B"
12
- MODEL_ID2 = "CohereForAI/aya-23-35B"
13
  MODELS = os.environ.get("MODELS")
14
  MODEL_NAME = MODELS.split("/")[-1]
15
 
16
- TITLE = "<h1><center>Aya-23-Chatbox</center></h1>"
17
 
18
  DESCRIPTION = f'<h3><center>MODEL: <a href="https://hf.co/{MODELS}">{MODEL_NAME}</a></center></h3>'
19
 
@@ -26,37 +26,14 @@ CSS = """
26
  }
27
  """
28
 
29
-
30
- #QUANTIZE
31
- QUANTIZE_4BIT = True
32
- USE_GRAD_CHECKPOINTING = True
33
- TRAIN_BATCH_SIZE = 2
34
- TRAIN_MAX_SEQ_LENGTH = 512
35
- USE_FLASH_ATTENTION = False
36
- GRAD_ACC_STEPS = 16
37
-
38
- quantization_config = None
39
-
40
- if QUANTIZE_4BIT:
41
- quantization_config = BitsAndBytesConfig(
42
- load_in_4bit=True,
43
- bnb_4bit_quant_type="nf4",
44
- bnb_4bit_use_double_quant=True,
45
- bnb_4bit_compute_dtype=torch.bfloat16,
46
- )
47
-
48
- attn_implementation = None
49
- if USE_FLASH_ATTENTION:
50
- attn_implementation="flash_attention_2"
51
-
52
  model = AutoModelForCausalLM.from_pretrained(
53
- MODELS,
54
- quantization_config=quantization_config,
55
- attn_implementation=attn_implementation,
56
- torch_dtype=torch.bfloat16,
57
- device_map="auto",
58
- )
59
- tokenizer = AutoTokenizer.from_pretrained(MODELS)
60
 
61
  @spaces.GPU
62
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
@@ -69,7 +46,7 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
69
 
70
  print(f"Conversation is -\n{conversation}")
71
 
72
- input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
73
 
74
  streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": True, "skip_prompt": True, 'clean_up_tokenization_spaces':False,})
75
 
@@ -79,6 +56,8 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
79
  max_new_tokens=max_new_tokens,
80
  do_sample=True,
81
  temperature=temperature,
 
 
82
  )
83
 
84
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
 
2
  from PIL import Image
3
  import gradio as gr
4
  import spaces
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  import os
7
  from threading import Thread
8
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
+ MODEL_ID = "THUDM/glm-4-9b-chat"
12
+ MODEL_ID2 = "THUDM/glm-4-9b-chat-1m"
13
  MODELS = os.environ.get("MODELS")
14
  MODEL_NAME = MODELS.split("/")[-1]
15
 
16
+ TITLE = "<h1><center>GLM-4-9B</center></h1>"
17
 
18
  DESCRIPTION = f'<h3><center>MODEL: <a href="https://hf.co/{MODELS}">{MODEL_NAME}</a></center></h3>'
19
 
 
26
  }
27
  """
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  model = AutoModelForCausalLM.from_pretrained(
30
+ MODELS,
31
+ torch_dtype=torch.bfloat16,
32
+ low_cpu_mem_usage=True,
33
+ trust_remote_code=True,
34
+ ).to(0).eval()
35
+
36
+ tokenizer = AutoTokenizer.from_pretrained(MODELS,trust_remote_code=True)
37
 
38
  @spaces.GPU
39
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
 
46
 
47
  print(f"Conversation is -\n{conversation}")
48
 
49
+ input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
50
 
51
  streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": True, "skip_prompt": True, 'clean_up_tokenization_spaces':False,})
52
 
 
56
  max_new_tokens=max_new_tokens,
57
  do_sample=True,
58
  temperature=temperature,
59
+ repetition_penalty=1.2,
60
+ eos_token_id=model.config.eos_token_id,
61
  )
62
 
63
  thread = Thread(target=model.generate, kwargs=generate_kwargs)