YANGSongsong committed on
Commit 144b240 • 1 Parent(s): 43b96b1

stable code 3b

Files changed (1)
  1. App.py +45 -0
App.py ADDED
@@ -0,0 +1,45 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
+
+ tokenizer = AutoTokenizer.from_pretrained("stabilityai/stable-code-3b", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     "stabilityai/stable-code-3b",
+     trust_remote_code=True,
+     torch_dtype="auto"
+ )
+
+
+ class StopOnTokens(StoppingCriteria):
+     """Halt generation as soon as the last generated token is one of the stop ids."""
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [0, 2]
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+
+ def chat(message, history):
+     stop = StopOnTokens()
+     history = history or []
+     inputs = tokenizer(message, return_tensors="pt").to(model.device)
+     tokens = model.generate(
+         **inputs,
+         max_new_tokens=4096,
+         temperature=0.2,
+         do_sample=True,
+         # Pass the stopping criterion so generation actually halts on the stop tokens.
+         stopping_criteria=StoppingCriteriaList([stop]),
+     )
+     # Decode only the newly generated tokens so the prompt is not echoed back.
+     response = tokenizer.decode(tokens[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+     history.append((message, response))
+     return history, history
+
+
+ iface = gr.Interface(
+     chat,
+     ["text", "state"],
+     ["chatbot", "state"],
+     allow_flagging="never"
+ )
+ iface.launch()