ubermenchh committed on
Commit 8400d16
1 Parent(s): 7fff83f

Create app.py

Files changed (1)
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from threading import Thread
+ import gradio as gr
+ import torch
+ 
+ MAX_INPUT_TOKEN_LENGTH = 4096
+ 
+ model_id = 'HuggingFaceH4/zephyr-7b-beta'
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map='auto')
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokenizer.use_default_system_prompt = False
+ 
+ def generate(message, chat_history=None, system_prompt=None, max_new_tokens=512, temperature=0.5, top_p=0.95, top_k=50, repetition_penalty=1.2):
+     # Rebuild the conversation in the role/content format expected by the chat template.
+     conversation = []
+     if system_prompt:
+         conversation.append({
+             'role': 'system',
+             'content': system_prompt
+         })
+     for user, assistant in (chat_history or []):
+         conversation.extend([
+             {'role': 'user', 'content': user},
+             {'role': 'assistant', 'content': assistant}
+         ])
+     conversation.append({
+         'role': 'user',
+         'content': message
+     })
+ 
+     # add_generation_prompt=True appends the assistant marker so the model answers as the assistant.
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors='pt', add_generation_prompt=True)
+     # Keep only the most recent tokens when the conversation exceeds the context budget.
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+ 
+     # Run generation in a background thread and stream partial text back to the UI.
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {'input_ids': input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+ 
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield ''.join(outputs)
+ 
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     examples=[
+         'What is GPT?',
+         'What is Life?',
+         'Who is Alan Turing?'
+     ]
+ )
+ 
+ chat_interface.queue(max_size=20).launch()
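
For reference, a minimal sketch (assuming the same zephyr-7b-beta checkpoint; the system prompt here is a hypothetical placeholder) of the prompt that apply_chat_template builds before generation. Zephyr's chat template wraps each turn in <|system|>, <|user|>, and <|assistant|> markers, which is why the app assembles the conversation as a list of role/content dicts:

# Sketch: print the raw prompt string instead of the tokenized tensor the app uses.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
conversation = [
    {'role': 'system', 'content': 'You are a friendly chatbot.'},  # hypothetical system prompt
    {'role': 'user', 'content': 'What is GPT?'}                    # one of the app's example inputs
]
# tokenize=False returns the formatted string; add_generation_prompt=True
# appends the <|assistant|> marker so the model replies as the assistant.
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
print(prompt)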