macadeliccc committed
Commit f7d8c6a
Parent: ce3745c

Update app.py

Files changed (1):
1. app.py +75 -51
app.py CHANGED
@@ -1,58 +1,82 @@
  import spaces
  import gradio as gr
  import torch
- from gradio import State
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- # Select the device (GPU if available, else CPU)
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # Load the tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
- model = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha").to(device)
- model.eval()  # Set the model to evaluation mode

  @spaces.GPU
- def generate_response(user_input, chat_history):
-     try:
-         prompt = "GPT4 Correct User: " + user_input + "GPT4 Correct Assistant: "
-         if chat_history:
-             prompt = chat_history[-1024:] + prompt  # Keep last 1024 tokens of history
-
-         inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
-         inputs = {k: v.to(device) for k, v in inputs.items()}  # Move input tensors to the same device as the model
-
-         with torch.no_grad():
-             output = model.generate(**inputs, max_length=512, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
-
-         response = tokenizer.decode(output[0], skip_special_tokens=True)
-         new_history = chat_history + prompt + response
-         return response, new_history[-1024:]  # Return last 1024 tokens of history
-
-     except Exception as e:
-         return f"Error occurred: {e}", chat_history
-
- # Gradio Interface
- def clear_chat():
-     return "", ""
-
- with gr.Blocks(gr.themes.Soft()) as app:
-     with gr.Row():
-         gr.Markdown("## Starling Chatbot")
-         gr.Markdown("Run with your own hardware. This application exceeds 24GB VRAM")
-         gr.Markdown("```docker run -it -p 7860:7860 --platform=linux/amd64 --gpus all registry.hf.space/macadeliccc-starling-lm-7b-alpha-chat:latest python app.py```")
-     with gr.Row():
-         chatbot = gr.Chatbot()
-
-     with gr.Row():
-         user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
-         send = gr.Button("Send")
-         clear = gr.Button("Clear")
-
-     chat_history = gr.State()  # Holds the chat history
-
-     send.click(generate_response, inputs=[user_input, chat_history], outputs=[chatbot, chat_history])
-     clear.click(clear_chat, outputs=[chatbot, chat_history])
-
- app.launch()
  import spaces
  import gradio as gr
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ from threading import Thread
+
+ # The model is loaded lazily, inside the request, to meet Hugging Face stateless GPU requirements.
+
+ # Custom stopping criteria for the model's text generation.
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [50256, 50295]  # Token IDs where generation should stop (see the note after the diff on verifying these).
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:  # Stop if the last generated token is a stop token.
+                 return True
+         return False
+
+ # Generate model predictions, streaming tokens as they are produced.
  @spaces.GPU
+ def predict(message, history):
+     torch.set_default_device("cuda")
+
+     # Load the tokenizer and model from the Hugging Face Hub on each call.
+     tokenizer = AutoTokenizer.from_pretrained(
+         "macadeliccc/laser-dolphin-mixtral-2x7b-dpo",
+         trust_remote_code=True
+     )
+     model = AutoModelForCausalLM.from_pretrained(
+         "macadeliccc/laser-dolphin-mixtral-2x7b-dpo",
+         torch_dtype="auto",
+         load_in_4bit=True,  # requires bitsandbytes; see the quantization sketch after the diff
+         trust_remote_code=True
+     )
+     history_transformer_format = history + [[message, ""]]
+     stop = StopOnTokens()
+
+     # Format the conversation as a ChatML prompt (an unpacked equivalent follows the diff).
+     system_prompt = "<|im_start|>system\nYou are Dolphin, a helpful AI assistant.<|im_end|>"
+     messages = system_prompt + "".join(["".join(["\n<|im_start|>user\n" + item[0], "<|im_end|>\n<|im_start|>assistant\n" + item[1]]) for item in history_transformer_format])
+     # The tokenizer returns a dict-like BatchEncoding (input_ids, attention_mask),
+     # so dict(model_inputs, ...) below merges it with the generation kwargs.
+     model_inputs = tokenizer([messages], return_tensors="pt").to('cuda')
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=50,
+         temperature=0.7,
+         num_beams=1,
+         stopping_criteria=StoppingCriteriaList([stop])
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()  # Start generation in a separate thread so tokens can be streamed as they arrive.
+     partial_message = ""
+     for new_token in streamer:
+         partial_message += new_token
+         if '<|im_end|>' in partial_message:  # Stop streaming once the end-of-turn marker appears.
+             break
+         yield partial_message
+
+
+ # Set up the Gradio chat interface.
+ gr.ChatInterface(predict,
+     description="""
+     <center><img src="https://huggingface.co/macadeliccc/laser-dolphin-mixtral-2x7b-dpo/resolve/main/dolphin_moe.png" width="33%"></center>\n\n
+     Chat with [macadeliccc/laser-dolphin-mixtral-2x7b-dpo](https://huggingface.co/macadeliccc/laser-dolphin-mixtral-2x7b-dpo), a Mixture of Experts built by merging two fine-tuned dolphin Mistral-7B models. Output is considered experimental.\n\n
+     ❤️ If you like this work, please follow me on [Hugging Face](https://huggingface.co/macadeliccc) and [LinkedIn](https://www.linkedin.com/in/tim-dolan-python-dev/).
+     """,
+     examples=[
+         'Can you solve the equation 2x + 3 = 11 for x?',
+         "How does Fermat's Last Theorem impact number theory?",
+         'What is a vector in the scope of computer science rather than physics?',
+         'Use a list comprehension to create a list of squares for numbers from 1 to 10.',
+         'Recommend some popular science fiction books.',
+         'Can you write a short story about a time-traveling detective?'
+     ],
+     theme=gr.themes.Soft(primary_hue="purple"),
+ ).launch()
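
Note on the hard-coded `stop_ids`: 50256 is GPT-2's `<|endoftext|>` ID, and a Mistral-based tokenizer has a much smaller vocabulary, so these values likely never match for this model. A safer approach is to look the IDs up from the tokenizer itself; a minimal check, assuming the tokenizer defines `<|im_end|>` as a token:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "macadeliccc/laser-dolphin-mixtral-2x7b-dpo", trust_remote_code=True
)
# Resolve the ChatML end-of-turn marker to its actual ID in this vocabulary;
# this is the value StopOnTokens should compare against.
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
print(im_end_id, tokenizer.eos_token_id)
```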
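On quantization: passing `load_in_4bit=True` directly to `from_pretrained` requires `bitsandbytes` to be installed, and recent transformers releases prefer an explicit `BitsAndBytesConfig`. A sketch of the equivalent load (the compute dtype here is an assumption, not something the commit specifies):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # assumed; choose to match your GPU
)
model = AutoModelForCausalLM.from_pretrained(
    "macadeliccc/laser-dolphin-mixtral-2x7b-dpo",
    quantization_config=quant_config,
    device_map="auto",
    trust_remote_code=True,
)
```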
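The one-line ChatML prompt assembly in `predict` is equivalent to the unpacked helper below (a readability sketch only; `build_chatml_prompt` is not part of the commit):

```python
def build_chatml_prompt(history, message, system="You are Dolphin, a helpful AI assistant."):
    # history is a list of [user, assistant] pairs; the current turn gets an
    # empty assistant slot so the prompt ends awaiting the assistant's reply.
    turns = history + [[message, ""]]
    prompt = "<|im_start|>system\n" + system + "<|im_end|>"
    for user_msg, assistant_msg in turns:
        prompt += "\n<|im_start|>user\n" + user_msg
        prompt += "<|im_end|>\n<|im_start|>assistant\n" + assistant_msg
    return prompt
```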