kingabzpro committed
Commit 9d1c8f9
1 Parent(s): 3c20001

Update app.py

Files changed (1): app.py (+46, -28)
app.py CHANGED
@@ -1,6 +1,11 @@
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+import time
+import numpy as np
+from torch.nn import functional as F
+import os
+from threading import Thread
 
 
 title = "🦅Falcon 🗨️ChatBot"
@@ -12,54 +17,67 @@ tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b")
 model = AutoModelForCausalLM.from_pretrained(
     "tiiuae/falcon-rw-1b",
     trust_remote_code=True,
-    torch_dtype=torch.float16
+    torch_dtype=torch.float16,
+    load_in_8bit=True
 )
 
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [29, 0]
+        stop_ids = [0]
         for stop_id in stop_ids:
             if input_ids[0][-1] == stop_id:
                 return True
         return False
 
-def predict(message, history):
-
-    history_transformer_format = history + [[message, ""]]
+
+def user(message, history):
+    # Append the user's message to the conversation history
+    return "", history + [[message, ""]]
+
+
+def chat(curr_system_message, history):
+    # Initialize a StopOnTokens object
     stop = StopOnTokens()
 
-    # Construct the input message string for the model by concatenating the current system message and conversation history
-    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])  # curr_system_message +
-                        for item in history_transformer_format])
-
-    # Tokenize the messages string
-    model_inputs = tokenizer([messages], return_tensors="pt")
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    # Construct the input message string for the model by concatenating the current system message and conversation history
+    messages = curr_system_message + \
+        "".join(["".join(["<user>: "+item[0], "<chatbot>: "+item[1]])
+                 for item in history])
+
+    # Tokenize the messages string
+    tokens = tokenizer([messages], return_tensors="pt")
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+
+    token_ids = tokens.input_ids
+    attention_mask = tokens.attention_mask
+
     generate_kwargs = dict(
-        model_inputs,
-        streamer=streamer,
-        max_new_tokens=1024,
+        input_ids=token_ids,
+        attention_mask=attention_mask,
+        streamer=streamer,
+        max_length=2048,
         do_sample=True,
-        top_p=0.95,
-        top_k=1000,
-        temperature=1.0,
-        num_beams=1,
+        num_return_sequences=1,
+        eos_token_id=tokenizer.eos_token_id,
+        temperature=0.7,
         stopping_criteria=StoppingCriteriaList([stop])
-        )
-    # t = Thread(target=model.generate, kwargs=generate_kwargs)
-    # t.start()
-    model.generate(**generate_kwargs)
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
 
     # Initialize an empty string to store the generated text
-    partial_message = ""
-    for new_token in streamer:
-        if new_token != '<':
-            partial_message += new_token
-            yield partial_message
-
+    partial_text = ""
+    for new_text in streamer:
+        # print(new_text)
+        partial_text += new_text
+        history[-1][1] = partial_text
+        # Yield the updated history so the chat window streams the reply
+        yield history
+    return partial_text
 
-gr.ChatInterface(predict,
+gr.ChatInterface(chat,
                  title=title,
                  description=description,
                  examples=examples,
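A note on `StopOnTokens`: the stop list is hard-coded (it shrinks from [29, 0] to [0] in this commit), and nothing here ties 0 to the tokenizer's actual end-of-text id. A sketch of a variant that looks the id up instead, assuming the tokenizer defines one; this is not part of the commit:

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Compare against the tokenizer's own EOS id rather than a hard-coded 0.
        return bool(input_ids[0][-1] == tokenizer.eos_token_id)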
 
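On the quantization change: `load_in_8bit=True` goes through the bitsandbytes integration, so it needs the `bitsandbytes` and `accelerate` packages and a CUDA device to hold the int8 weights, usually selected with a `device_map`. A minimal sketch under those assumptions; `device_map="auto"` is illustrative and not part of this commit:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-rw-1b",
    trust_remote_code=True,
    torch_dtype=torch.float16,   # dtype for the modules that stay unquantized
    load_in_8bit=True,           # int8 linear layers via bitsandbytes
    device_map="auto",           # let accelerate place the weights on the GPU
)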
 
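The central fix in the commit is moving `model.generate` onto a background thread. `TextIteratorStreamer` is a blocking iterator fed from inside `generate`, so the previous synchronous call meant nothing could be read from the streamer until generation had already finished (and the 10-second timeout could then fire). A standalone sketch of the pattern, assuming `model` and `tokenizer` are loaded as in app.py; the prompt is illustrative:

from threading import Thread
from transformers import TextIteratorStreamer

tokens = tokenizer(["<user>: Hello<chatbot>: "], return_tensors="pt")
streamer = TextIteratorStreamer(
    tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

# generate() runs in the background and feeds decoded text into the streamer.
Thread(target=model.generate, kwargs=dict(
    input_ids=tokens.input_ids,
    attention_mask=tokens.attention_mask,
    streamer=streamer,
    max_new_tokens=64,
)).start()

# The main thread drains the streamer chunk by chunk while generation runs,
# which is what lets the UI update incrementally.
partial = ""
for chunk in streamer:
    partial += chunk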
 
 
 
 
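One wiring detail worth flagging: `gr.ChatInterface` calls its function as `fn(message, history)` and, when the function is a generator, displays each yielded string as the growing reply, which is the shape the old `predict` had. The new `chat(curr_system_message, history)` receives the user's message in its first parameter and yields the whole `history` list, a pattern that belongs to the lower-level `gr.Blocks` + `gr.Chatbot` API. A sketch of a generator shaped for `ChatInterface`, reusing the threaded-streamer pattern above; `chat_fn` is a hypothetical name, not from this commit:

from threading import Thread
from transformers import TextIteratorStreamer

def chat_fn(message, history):
    # Rebuild the prompt from prior [user, bot] turns plus the new message.
    prompt = "".join("<user>: " + u + "<chatbot>: " + b for u, b in history)
    prompt += "<user>: " + message + "<chatbot>: "
    tokens = tokenizer([prompt], return_tensors="pt")
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    Thread(target=model.generate, kwargs=dict(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        streamer=streamer,
        max_new_tokens=1024,
    )).start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial  # ChatInterface renders this string as the streaming reply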