arjunanand13 commited on
Commit
c62a6a1
·
verified ·
1 Parent(s): f573b86

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch import cuda, bfloat16
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
5
+ from langchain.llms import HuggingFacePipeline
6
+ from langchain.vectorstores import FAISS
7
+ from langchain.chains import ConversationalRetrievalChain
8
+ import gradio as gr
9
+ from langchain.embeddings import HuggingFaceEmbeddings
10
+ from sentence_transformers import CrossEncoder
11
+
12
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
+
14
+ class StopOnTokens(StoppingCriteria):
15
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
16
+ for stop_ids in stop_token_ids:
17
+ if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
18
+ return True
19
+ return False
20
+
21
+ model_id = 'meta-llama/Meta-Llama-3-8B-Instruct'
22
+ device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
23
+
24
+ bnb_config = BitsAndBytesConfig(
25
+ load_in_4bit=True,
26
+ bnb_4bit_quant_type='nf4',
27
+ bnb_4bit_use_double_quant=True,
28
+ bnb_4bit_compute_dtype=bfloat16
29
+ )
30
+
31
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
32
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", token=HF_TOKEN, quantization_config=bnb_config)
33
+
34
+ stop_list = ['\nHuman:', '\n```\n']
35
+ stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
36
+ stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
37
+ stopping_criteria = StoppingCriteriaList([StopOnTokens()])
38
+
39
+ generate_text = pipeline(
40
+ model=model,
41
+ tokenizer=tokenizer,
42
+ return_full_text=True,
43
+ task='text-generation',
44
+ stopping_criteria=stopping_criteria,
45
+ temperature=0.1,
46
+ max_new_tokens=512,
47
+ repetition_penalty=1.1
48
+ )
49
+
50
+ llm = HuggingFacePipeline(pipeline=generate_text)
51
+
52
+ """Load the stored FAISS index"""
53
+ try:
54
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda"})
55
+ vectorstore = FAISS.load_local('faiss_index', embeddings)
56
+ print("Loaded embeddings from FAISS Index successfully")
57
+ except ImportError as e:
58
+ print("FAISS could not be imported. Make sure FAISS is installed correctly.")
59
+ raise e
60
+
61
+ chain = ConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(), return_source_documents=True)
62
+
63
+ chat_history = []
64
+
65
+ reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
66
+
67
+ def format_prompt(query):
68
+ prompt = f"""
69
+ You are a knowledgeable assistant with access to a comprehensive database.
70
+ I need you to answer my question and provide related information in a specific format.
71
+ Here's what I need:
72
+ 1. A brief, general response to my question based on related answers retrieved.
73
+ 2. A JSON-formatted output containing:
74
+ - "question": The original question.
75
+ - "answer": The detailed answer.
76
+ - "related_questions": A list of related questions and their answers, each as a dictionary with the keys:
77
+ - "question": The related question.
78
+ - "answer": The related answer.
79
+ Here's my question:
80
+ {query}
81
+ Include a brief final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
82
+ """
83
+ return prompt
84
+
85
+ def qa_infer(query):
86
+ formatted_prompt = format_prompt(query)
87
+ results = chain({"question": formatted_prompt, "chat_history": chat_history})
88
+
89
+ documents = results['source_documents']
90
+ query_document_pairs = [[query, doc.page_content] for doc in documents]
91
+ scores = reranker.predict(query_document_pairs)
92
+
93
+ """Sort documents based on the re-ranker scores"""
94
+ ranked_docs = sorted(zip(scores, documents), key=lambda x: x[0], reverse=True)
95
+
96
+ """Extract the best document"""
97
+ best_doc = ranked_docs[0][1].page_content if ranked_docs else ""
98
+
99
+ return best_doc
100
+
101
+ EXAMPLES = ["How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
102
+ "Can BQ25896 support I2C interface?",
103
+ "Does TDA2 vout support bt656 8-bit mode?"]
104
+
105
+ demo = gr.Interface(fn=qa_infer, inputs="text", allow_flagging='never', examples=EXAMPLES, cache_examples=False, outputs="text")
106
+ demo.launch()