xi0v committed on
Commit
d6164db
1 Parent(s): 6e43b71

Update app.py

Files changed (1)
  1. app.py +109 -45
app.py CHANGED
@@ -1,65 +1,129 @@
  #!/usr/bin/env python
  import gradio as gr
  import spaces
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
- import time
- import numpy as np
- from torch.nn import functional as F
- import os
- from threading import Thread
-
- print(f"Starting to load the model to memory")
- m = AutoModelForCausalLM.from_pretrained(
-     "xi0v/aether-7b-chat-v1.0", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, trust_remote_code=False)
- tok = AutoTokenizer.from_pretrained("xi0v/aether-7b-chat-v1.0", trust_remote_code=False)
- # using CUDA for an optimal experience
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- m = m.to(device)
- print(f"Sucessfully loaded the model to the memory")
-
- start_message = "You are a Helpful assistant"
-
- def user(message, history):
-     # Append the user's message to the conversation history
-     return "", history + [[message, ""]]
-
  @spaces.GPU
- def chat(message, history):
-     chat = []
-     for item in history:
-         chat.append({"role": "user", "content": item[0]})
-         if item[1] is not None:
-             chat.append({"role": "assistant", "content": item[1]})
-     chat.append({"role": "user", "content": message})
-     messages = tok.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
-     # Tokenize the messages string
-     model_inputs = tok([messages], return_tensors="pt").to(device)
-     streamer = TextIteratorStreamer(
-         tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = dict(
-         model_inputs,
          streamer=streamer,
-         max_new_tokens=1024,
          do_sample=True,
-         top_p=0.95,
-         top_k=1000,
-         temperature=0.75,
          num_beams=1,
      )
-     t = Thread(target=m.generate, kwargs=generate_kwargs)
      t.start()

-     # Initialize an empty string to store the generated text
-     partial_text = ""
-     for new_text in streamer:
-         # print(new_text)
-         partial_text += new_text
-         # Yield an empty string to cleanup the message textbox and the updated conversation history
-         yield partial_text

- demo = gr.ChatInterface(fn=chat, examples=["hello", "hola", "merhaba"], title="Stable LM 2 Zephyr 1.6b")
- demo.launch()
  #!/usr/bin/env python
+
+ import os
+ from threading import Thread
+ from typing import Iterator
+
  import gradio as gr
  import spaces
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

+ DESCRIPTION = "# Aether-7b v1.0"
+
+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+ MAX_MAX_NEW_TOKENS = 4096
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ if torch.cuda.is_available():
+     model_id = "xi0v/aether-7b-chat-v1.0"
+     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+     tokenizer = AutoTokenizer.from_pretrained(model_id)


  @spaces.GPU
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     system_prompt: str = "",
+     max_new_tokens: int = 1024,
+     temperature: float = 0.7,
+     top_p: float = 0.95,
+     top_k: int = 50,
+     repetition_penalty: float = 1.0,
+ ) -> Iterator[str]:
+     conversation = []
+     if system_prompt:
+         conversation.append({"role": "system", "content": system_prompt})
+     for user, assistant in chat_history:
+         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = dict(
+         {"input_ids": input_ids},
          streamer=streamer,
+         max_new_tokens=max_new_tokens,
          do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
          num_beams=1,
+         repetition_penalty=repetition_penalty,
      )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
      t.start()

+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)


+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Textbox(
+             label="System prompt",
+             lines=6,
+             placeholder="You are a friendly chatbot who always responds in the style of a pirate.",
+         ),
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.7,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.95,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.0,
+         ),
+     ],
+     stop_btn=None,
+ )
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use",
+         elem_id="duplicate-button",
+         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+     )
+     chat_interface.render()

+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
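
A minimal usage sketch for the rewritten generate() streaming API, assuming a CUDA machine (the model and tokenizer are only loaded when torch.cuda.is_available() is True) and that app.py is importable from the working directory; the history contents below are illustrative only, not part of the commit.

# Hypothetical smoke test for generate(); not part of app.py.
from app import generate

history = [("hello", "Hi! How can I help you today?")]

# generate() yields progressively longer partial strings as tokens stream in;
# the last value yielded is the complete reply.
reply = ""
for partial in generate(
    message="Summarize our conversation so far.",
    chat_history=history,
    system_prompt="You are a helpful assistant.",
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
    top_k=50,
    repetition_penalty=1.0,
):
    reply = partial
print(reply)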