thobuiq committed
Commit 5059db6
Parent: 2acde50

Update app.py

Files changed (1)
  1. app.py +7 -14
app.py CHANGED
@@ -1,23 +1,12 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 from threading import Thread
 
-
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_use_double_quant=True,
-)
 # Loading the tokenizer and model from Hugging Face's model hub.
-tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
-model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1",
-                                             load_in_4bit=True,
-                                             quantization_config=bnb_config,
-                                             torch_dtype=torch.bfloat16,
-                                             device_map="cpu",
-                                             trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 
 # using CUDA for an optimal experience
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -34,11 +23,13 @@ class StopOnTokens(StoppingCriteria):
         return False
 
 
+
 # Function to generate model predictions.
 def predict(message, history):
     history_transformer_format = history + [[message, ""]]
     stop = StopOnTokens()
 
+
     # Formatting the input for the model.
     messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
                             for item in history_transformer_format])
@@ -65,6 +56,8 @@ def predict(message, history):
             yield partial_message
 
 
+
+
 # Setting up the Gradio chat interface.
 gr.ChatInterface(predict,
                  title="Tinyllama_chatBot",