stream output and read from history

#1
Files changed (1) hide show
  1. app.py +29 -7
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  import os
3
  import spaces
4
- from transformers import GemmaTokenizer, AutoModelForCausalLM
 
5
 
6
  # Read the Hugging Face access token from the HF_TOKEN environment variable (None if unset)
7
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -24,14 +25,35 @@ def codegemma(message: str, history: list, temperature: float, max_new_tokens: i
24
  Returns:
25
  str: The generated response.
26
  """
27
- input_ids = tokenizer(message, return_tensors="pt").to("cuda:0")
28
- outputs = model.generate(
29
- **input_ids,
30
- temperature=temperature,
 
 
 
 
 
 
 
 
 
 
 
31
  max_new_tokens=max_new_tokens,
 
32
  )
33
- response = tokenizer.decode(outputs[0])
34
- return response
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  placeholder = """
 
1
  import gradio as gr
2
  import os
3
  import spaces
4
+ from transformers import GemmaTokenizer, AutoModelForCausalLM, TextIteratorStreamer
5
+ from threading import Thread
6
 
7
  # Read the Hugging Face access token from the HF_TOKEN environment variable (None if unset)
8
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
25
  Returns:
26
  str: The generated response.
27
  """
28
+ chat = []
29
+ for item in history:
30
+ chat.append({"role": "user", "content": item[0]})
31
+ if item[1] is not None:
32
+ chat.append({"role": "assistant", "content": item[1]})
33
+ chat.append({"role": "user", "content": message})
34
+ messages = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
35
+ # Tokenize the messages string
36
+ model_inputs = tokenizer([messages], return_tensors="pt").to(device)
37
+ streamer = TextIteratorStreamer(
38
+ tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
39
+
40
+ generate_kwargs = dict(
41
+ model_inputs,
42
+ streamer=streamer,
43
  max_new_tokens=max_new_tokens,
44
+ temperature=temperature,
45
  )
46
+
47
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
48
+ t.start()
49
+
50
+ # Initialize an empty string to store the generated text
51
+ partial_text = ""
52
+ for new_text in streamer:
53
+ # print(new_text)
54
+ partial_text += new_text
55
+ # Yield the text accumulated so far so the response streams incrementally
56
+ yield partial_text
57
 
58
 
59
  placeholder = """