harshraj committed on
Commit d08b170
1 Parent(s): f7e15cc

Upload 2 files

RAG with TinyLLaMA

Files changed (3)
  1. .gitattributes +1 -0
  2. RAG_inference.py +144 -0
  3. hinglish_dataset_opt.csv +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ hinglish_dataset_opt.csv filter=lfs diff=lfs merge=lfs -text
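
For reference (not part of this commit): a filter line like the one above is what Git LFS itself writes into .gitattributes when a file is put under LFS tracking, for example by running the following from the repository root before committing the CSV:

    git lfs track "hinglish_dataset_opt.csv"
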
RAG_inference.py ADDED
@@ -0,0 +1,144 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ from threading import Thread
+
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
+ from llama_index.llms.huggingface import HuggingFaceLLM
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+
+ document = SimpleDirectoryReader("sft/dataset").load_data()
+
+ from llama_index.core import PromptTemplate
+
+ system_prompt = "You are a QA bot. Given the questions answer it correctly."
+
+ query_wrapper_prompt = PromptTemplate("<|user|>:{query_str}\n<|assistant|>:")
+
+ llm = HuggingFaceLLM(
+     context_window=2048,
+     max_new_tokens=256,
+     generate_kwargs={"temperature":0.0, "do_sample":False},
+     system_prompt=system_prompt,
+     query_wrapper_prompt=query_wrapper_prompt,
+     tokenizer_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+     model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+     device_map="cuda",
+     model_kwargs={"torch_dtype":torch.bfloat16},
+ )
+
+ embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
+
+ service_context = ServiceContext.from_defaults(
+     chunk_size=256,
+     llm=llm,
+     embed_model=embed_model
+ )
+
+ index = VectorStoreIndex.from_documents(document, service_context=service_context)
+
+ query_engine = index.as_query_engine()
+
+
+ # Defining a custom stopping criteria class for the model's text generation.
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [2] # IDs of tokens where the generation should stop.
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id: # Checking if the last generated token is a stop token.
+                 return True
+         return False
+
+ # Function to generate model predictions.
+ def predict(message, history):
+     history_transformer_format = history + [[message, ""]]
+     stop = StopOnTokens()
+
+     # Formatting the input for the model.
+     messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
+                             for item in history_transformer_format])
+
+     model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=50,
+         temperature=0.5,
+         num_beams=1,
+         stopping_criteria=StoppingCriteriaList([stop])
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start() # Starting the generation in a separate thread.
+     partial_message = ""
+     for new_token in streamer:
+         partial_message += new_token
+         if '</s>' in partial_message: # Breaking the loop if the stop token is generated.
+             break
+         yield partial_message
+
+
+ def predict(input, history):
+     response = query_engine.query(input)
+     return str(response)
+
+ gr.ChatInterface(predict).launch(share=True)
+ # # Loading the tokenizer and model from Hugging Face's model hub.
+ # tokenizer = AutoTokenizer.from_pretrained("output/1T_FT_lr1e-5_ep5_top1_2024-03-04/checkpoint-575")
+ # model = AutoModelForCausalLM.from_pretrained("output/1T_FT_lr1e-5_ep5_top1_2024-03-04/checkpoint-575")
+
+ # # using CUDA for an optimal experience
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ # model = model.to(device)
+
+ # # Defining a custom stopping criteria class for the model's text generation.
+ # class StopOnTokens(StoppingCriteria):
+ #     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ #         stop_ids = [2] # IDs of tokens where the generation should stop.
+ #         for stop_id in stop_ids:
+ #             if input_ids[0][-1] == stop_id: # Checking if the last generated token is a stop token.
+ #                 return True
+ #         return False
+
+ # # Function to generate model predictions.
+ # def predict(message, history):
+ #     history_transformer_format = history + [[message, ""]]
+ #     stop = StopOnTokens()
+
+ #     # Formatting the input for the model.
+ #     messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
+ #                             for item in history_transformer_format])
+
+ #     model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+ #     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+ #     generate_kwargs = dict(
+ #         model_inputs,
+ #         streamer=streamer,
+ #         max_new_tokens=1024,
+ #         do_sample=True,
+ #         top_p=0.95,
+ #         top_k=50,
+ #         temperature=0.5,
+ #         num_beams=1,
+ #         stopping_criteria=StoppingCriteriaList([stop])
+ #     )
+ #     t = Thread(target=model.generate, kwargs=generate_kwargs)
+ #     t.start() # Starting the generation in a separate thread.
+ #     partial_message = ""
+ #     for new_token in streamer:
+ #         partial_message += new_token
+ #         if '</s>' in partial_message: # Breaking the loop if the stop token is generated.
+ #             break
+ #         yield partial_message
+
+
+ # # Setting up the Gradio chat interface.
+ # gr.ChatInterface(predict,
+ #                  title="Tinyllama_chatBot",
+ #                  description="Ask Tiny llama any questions",
+ #                  examples=['How to cook a fish?', 'Who is the president of US now?']
+ #                  ).launch(share=True) # Launching the web interface.
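
A note on the script as committed: predict is defined twice, and because Python rebinds the name top to bottom, the gr.ChatInterface(predict).launch(...) call picks up the second, RAG-backed definition; the earlier streaming variant would also need the tokenizer, model, and device objects that are only created in the commented-out block. As a minimal sketch (not part of the commit), the same index can be queried directly without the Gradio UI, assuming the document loading and index construction above have already run; the question string is a made-up example:

    question = "What does the Hinglish dataset cover?"  # hypothetical example query
    response = query_engine.query(question)  # retrieve relevant chunks, then answer with TinyLlama
    print(str(response))
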
hinglish_dataset_opt.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:909a0a8877746a55d2c34eb7d5e1786165a8abef077773816f7994bca76a61cc
+ size 22937637
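
The committed file is only a Git LFS pointer; the CSV contents live in LFS storage and are identified by the sha256 object id and size recorded above. A minimal sketch (hypothetical helper, not part of the commit) for checking that a locally pulled copy of the CSV matches the pointer:

    import hashlib

    def sha256_of(path, chunk_size=1 << 20):
        # Hash the file in chunks so the ~23 MB CSV never has to sit fully in memory.
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    expected_oid = "909a0a8877746a55d2c34eb7d5e1786165a8abef077773816f7994bca76a61cc"
    print(sha256_of("hinglish_dataset_opt.csv") == expected_oid)  # True if the pulled file matches the pointer
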