SivaResearch committed on
Commit 21621e1
1 Parent(s): e4b1f2c

Update app.py

Files changed (1)
  1. app.py +61 -4
app.py CHANGED
@@ -1,7 +1,64 @@
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
 
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+from threading import Thread
+
+
+# Load the tokenizer and model from the Hugging Face model hub.
+tokenizer = AutoTokenizer.from_pretrained("SivaResearch/tinyllama-Siv-v2")
+model = AutoModelForCausalLM.from_pretrained("SivaResearch/tinyllama-Siv-v2")
+
+# Use CUDA when available for an optimal experience, otherwise fall back to CPU.
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = model.to(device)
+
+
+# Custom stopping criteria for the model's text generation.
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = [2]  # IDs of tokens at which generation should stop (EOS).
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:  # Check whether the last generated token is a stop token.
+                return True
+        return False
+
+
+# Generate model predictions, streaming tokens back to the chat UI.
+def predict(message, history):
+    history_transformer_format = history + [[message, ""]]
+    stop = StopOnTokens()
+
+    # Format the chat history into the model's prompt template.
+    messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
+                            for item in history_transformer_format])
+    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.95,
+        top_k=50,
+        temperature=0.7,
+        num_beams=1,
+        stopping_criteria=StoppingCriteriaList([stop])
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()  # Start generation in a separate thread so tokens can be streamed.
+
+    partial_message = ""
+    for new_token in streamer:
+        partial_message += new_token
+        if '</s>' in partial_message:  # Break the loop if the stop token is generated.
+            break
+        yield partial_message
+
+
+# Set up the Gradio chat interface.
+gr.ChatInterface(predict,
+                 title="Tinyllama_chatBot",
+                 description="Ask Tiny llama any questions",
+                 examples=['How to cook a fish?', 'Who is the president of US now?']
+                 ).launch()  # Launch the web interface.
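
For reference, below is a minimal sketch (not part of this commit; the sample history and message are made up for illustration) of the prompt string that predict() assembles from the chat history before tokenization. The "<|user|>" / "<|assistant|>" markers and "</s>" separators mirror the code above.

# Minimal sketch, not part of this commit: reproduce the prompt string that
# predict() builds for one prior exchange plus a new user message.
history = [["Hi", "Hello! How can I help?"]]   # prior [user, assistant] turns from the Gradio UI
message = "How to cook a fish?"                # new user message

history_transformer_format = history + [[message, ""]]
prompt = "</s>".join(
    "</s>".join(["\n<|user|>:" + user, "\n<|assistant|>:" + assistant])
    for user, assistant in history_transformer_format
)
print(prompt)
# Printed output (the trailing "<|assistant|>:" is where the model continues):
# <|user|>:Hi</s>
# <|assistant|>:Hello! How can I help?</s>
# <|user|>:How to cook a fish?</s>
# <|assistant|>: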