vonewman committed
Commit d878d30
1 Parent(s): f1c2874

Create main.py

Files changed (1)
  1. main.py +53 -0
main.py ADDED
@@ -0,0 +1,53 @@
+ from threading import Thread
+
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ MAX_INPUT_TOKEN_LENGTH = 4096
+
+ # NOTE: the original file used `model` and `tokenizer` without defining them.
+ # The checkpoint below is an assumption; swap in whichever chat model you serve.
+ model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
+
+
+ def generate(message, chat_history):
+     # Step 1: pre-process the inputs into the chat-template message format
+     conversation = []
+     for user, assistant in chat_history:
+         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+
+     # In case the input exceeds the maximum length, keep only the most recent tokens
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+
+     input_ids = input_ids.to(model.device)
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+     # Step 2: define generation arguments
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+     )
+
+     # Run generation in a background thread so tokens can be streamed as they arrive
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     # Step 3: accumulate the streamed text and yield the growing partial response
+     outputs = ""
+     for text in streamer:
+         outputs += text
+         yield outputs
+
+
+ chat_interface = gr.ChatInterface(generate)
+ chat_interface.queue().launch(share=True)
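
A quick way to exercise the streaming endpoint once the app is up, sketched with gradio_client. This is not part of the commit: the URL is a placeholder for whatever launch() prints, and the /chat route assumes a recent Gradio 4.x release of gr.ChatInterface.

from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # placeholder: use the local or share URL printed by launch()
reply = client.predict("Hello! What can you do?", api_name="/chat")  # /chat is ChatInterface's default API route in Gradio 4.x
print(reply)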