switiz87 committed
Commit
3aa7c64
1 Parent(s): 5e4cf3d

Update app.py

Files changed (1):
  app.py (+4, -4)
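
In short: the default generation length drops from 256 to 128 tokens (both DEFAULT_MAX_NEW_TOKENS and the generate() default), the previously commented-out @spaces.GPU decorator is re-enabled (now without an explicit duration), and the TextIteratorStreamer timeout is raised from 20 s to 60 s.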
app.py CHANGED
@@ -14,7 +14,7 @@ DESCRIPTION = """ EXAONE-3.0-7.8B-Instruct Official Demo \
 """
 
 MAX_MAX_NEW_TOKENS = 4096
-DEFAULT_MAX_NEW_TOKENS = 256
+DEFAULT_MAX_NEW_TOKENS = 128
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "3840"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -30,12 +30,12 @@ model = AutoModelForCausalLM.from_pretrained(
 model.eval()
 
 
-#@spaces.GPU(duration=90)
+@spaces.GPU()
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
-    max_new_tokens: int = 256,
+    max_new_tokens: int = 128,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
@@ -63,7 +63,7 @@ def generate(
         gr.Warning(f"Trimmed input from messages as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
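
For context, here is a minimal, self-contained sketch of the threaded streaming pattern these lines plug into, following the standard transformers TextIteratorStreamer recipe. The model id comes from the demo title; the prompt, dtype, and scaffolding around generate_kwargs are assumptions for illustration, not the Space's exact code (the real app builds input_ids from the system prompt and chat history via the tokenizer's chat template).

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"  # model named in the demo title
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,  # EXAONE 3.0 ships custom modeling code
)
model.eval()

# Placeholder prompt; the Space assembles input_ids from the conversation instead.
input_ids = tokenizer("Hello!", return_tensors="pt").input_ids.to(model.device)

# timeout=60.0: iterating the streamer raises queue.Empty if no token arrives
# within 60 s, so the raised timeout (previously 20 s) tolerates slow
# first-token latency, e.g. while a GPU is being allocated.
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
    {"input_ids": input_ids},
    streamer=streamer,
    max_new_tokens=128,  # the new DEFAULT_MAX_NEW_TOKENS
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    top_k=50,
)

# model.generate() runs in a worker thread; the caller consumes partial text
# from the streamer as it is produced.
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
chunks = []
for text in streamer:
    chunks.append(text)
    # a Gradio handler would yield "".join(chunks) here to update the UI
print("".join(chunks))

On a ZeroGPU Space, the re-enabled @spaces.GPU() decorator requests a GPU for each call to generate(); omitting the duration argument falls back to the library's default allocation window.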