arjunanand13 committed
Commit
e44519b
1 Parent(s): b24351f

Create app.py

Files changed (1)
  1. app.py +130 -0
app.py ADDED
@@ -0,0 +1,130 @@
+ import os
+ import gradio as gr
+ import transformers
+ import accelerate
+ import einops
+ import langchain
+ import xformers
+ import bitsandbytes
+ import sentence_transformers
+ import huggingface_hub
+ import torch
+ from torch import cuda, bfloat16
+ from transformers import StoppingCriteria, StoppingCriteriaList
+ from langchain.llms import HuggingFacePipeline
+ from langchain.document_loaders import TextLoader, DirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.chains import ConversationalRetrievalChain
+
+ # Login to Hugging Face; the access token is read from the HF_TOKEN environment
+ # variable (e.g. a Space secret) so it is not hard-coded in the app
+ huggingface_hub.login(token=os.environ.get("HF_TOKEN"))
+
+ """
+ Load the Llama 3 model.
+ """
+
+ model_id = 'meta-llama/Meta-Llama-3-8B-Instruct'
+ device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
+
+ # set quantization configuration to load the large model with less GPU memory;
+ # this requires the `bitsandbytes` library
+ bnb_config = transformers.BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type='nf4',
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=bfloat16
+ )
+
+ model_config = transformers.AutoConfig.from_pretrained(
+     model_id,
+ )
+
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+     model_id,
+     trust_remote_code=True,
+     config=model_config,
+     quantization_config=bnb_config,
+     device_map='auto',
+ )
+
+ # enable evaluation mode for inference
+ model.eval()
+ print(f"Model loaded on {device}")
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+     model_id,
+ )
+
+ """
+ Set up the stop list that defines the stopping criteria.
+ """
+
+ stop_list = ['\nHuman:', '\n```\n']
+
+ stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
+ stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
+
+
+ # custom stopping criteria: stop generation as soon as the output ends with
+ # any of the stop sequences above
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         for stop_ids in stop_token_ids:
+             if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
+                 return True
+         return False
+
+ stopping_criteria = StoppingCriteriaList([StopOnTokens()])
+
+
+ generate_text = transformers.pipeline(
+     model=model,
+     tokenizer=tokenizer,
+     return_full_text=True,  # langchain expects the full text
+     task='text-generation',
+     # we pass model parameters here too
+     stopping_criteria=stopping_criteria,  # without this the model rambles during chat
+     temperature=0.1,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
+     max_new_tokens=512,  # max number of tokens to generate in the output
+     repetition_penalty=1.1  # without this the output begins repeating
+ )
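+
+ # optional smoke test of the raw pipeline before wiring it into LangChain
+ # (the prompt is only illustrative):
+ # print(generate_text("Briefly describe the BQ24040 battery charger.")[0]['generated_text'])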
+
+ llm = HuggingFacePipeline(pipeline=generate_text)
+
+ loader = DirectoryLoader('data/text/', loader_cls=TextLoader)
+ documents = loader.load()
+ print('Number of documents loaded:', len(documents))
+
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
+ all_splits = text_splitter.split_documents(documents)
+
+ model_name = "sentence-transformers/all-mpnet-base-v2"
+ model_kwargs = {"device": "cuda"}
+
+ embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
+
+ # store the document embeddings in a FAISS vector store
+ vectorstore = FAISS.from_documents(all_splits, embeddings)
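+ # optionally, the index could be persisted so the documents do not have to be
+ # re-embedded on every startup (the folder name 'faiss_index' is arbitrary), e.g.:
+ # vectorstore.save_local('faiss_index')
+ # vectorstore = FAISS.load_local('faiss_index', embeddings)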
+
+ chain = ConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(), return_source_documents=True)
+
+ chat_history = []
+ def qa_infer(query):
+     result = chain({"question": query, "chat_history": chat_history})
+     print(result['answer'])
+     return result['answer']
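+
+ # a minimal sketch (not used by the Gradio UI below) of a stateful variant that
+ # feeds earlier turns back into the chain; ConversationalRetrievalChain expects
+ # chat_history as a list of (question, answer) tuples:
+ # def qa_infer_with_history(query):
+ #     result = chain({"question": query, "chat_history": chat_history})
+ #     chat_history.append((query, result['answer']))
+ #     return result['answer']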
+
+ # query = "What is the best TS pin configuration for BQ24040 in normal battery charge mode"
+ # qa_infer(query)
+
+ EXAMPLES = ["What is the best TS pin configuration for BQ24040 in normal battery charge mode",
+             "Can BQ25896 support I2C interface?",
+             "Can you please provide me with Gerber/CAD file for UCC2897A"]
+
+ demo = gr.Interface(fn=qa_infer, inputs="text", outputs="text",
+                     allow_flagging='never', examples=EXAMPLES,
+                     cache_examples=False)
+
+ # launch the app!
+ # demo.launch(enable_queue=True, share=True)
+ # demo.queue(default_enabled=True).launch(debug=True, share=True)
+ demo.launch()