PetrovDE committed on
Commit
04c4044
1 Parent(s): 48a1d18

Upload 2 files

Files changed (2)
  1. app.py +135 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,135 @@
+ import os
+ from threading import Thread
+ from typing import Iterator
+
+ import gradio as gr
+ import torch
+ from peft import PeftModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
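+
+ # --- Model setup: load the base phi-2 chat model and merge in the LoRA adapter ---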
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model_name = "WeeRobots/phi-2-chat-v05"
+ new_model = "phi-2-sheldon"
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     low_cpu_mem_usage=True,
+     return_dict=True,
+     torch_dtype=torch.float32,
+     trust_remote_code=True,
+ )
+ model = PeftModel.from_pretrained(base_model, new_model)
+ model = model.merge_and_unload()
+ model = model.to(device)  # move the merged model to the GPU when one is available
+
+ # Load the tokenizer and give it a pad token (phi-2 ships without a dedicated one)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "right"
+
+ MAX_MAX_NEW_TOKENS = 200
+ DEFAULT_MAX_NEW_TOKENS = 100
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
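+
+
+ # `generate` runs model.generate on a background thread and yields the
+ # accumulated response as tokens stream in, so the UI updates incrementally.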
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     system_prompt: str,
+     max_new_tokens: int = 200,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     # Rebuild the conversation in the message format expected by apply_chat_template
+     conversation = []
+     if system_prompt:
+         conversation.append({"role": "system", "content": system_prompt})
+     for user, assistant in chat_history:
+         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         # Keep only the most recent tokens so the prompt fits the context budget
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+         eos_token_id=tokenizer.eos_token_id,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
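+
+
+ # Chat UI: exposes the sampling parameters as extra controls next to the chat box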
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Textbox(label="System prompt", lines=3),
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["In a Sperm bank"],
+     ],
+ )
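+
+
+ # Compose and launch the app; queue(max_size=20) caps the number of pending requests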
+ with gr.Blocks() as demo:
+     chat_interface.render()
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
requirements.txt ADDED
Binary file (120 Bytes).