arjunanand13 committed on
Commit
2497fee
1 Parent(s): eae3c36

Create app.py

Files changed (1)
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
+ import os
+ import torch
+ from torch import cuda, bfloat16
+ from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig, StoppingCriteria
+ from langchain.llms import HuggingFacePipeline
+ from langchain.llms.base import LLM
+ from langchain.vectorstores import FAISS
+ from langchain.chains import ConversationalRetrievalChain
+ import gradio as gr
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from huggingface_hub import InferenceClient  # InferenceClient lives in huggingface_hub, not transformers
+
+ # Load the Hugging Face token from the environment
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+ # Mistral model served through the HF Inference API, plus a local tokenizer
+ # (the tokenizer is needed below to encode the stop sequences)
+ model_id = 'mistralai/Mistral-7B-Instruct-v0.3'
+ client = InferenceClient(model_id, token=HF_TOKEN)
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
+
+ # Stopping criteria for local transformers generation (not used by the remote
+ # InferenceClient path, which takes plain stop strings instead)
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids, scores, **kwargs):
+         for stop_ids in stop_token_ids:
+             if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
+                 return True
+         return False
+
+ # Stop sequences, encoded with the local tokenizer
+ stop_list = ['\nHuman:', '\n```\n']
+ stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
+ stop_token_ids = [torch.LongTensor(x).to(cuda.current_device() if cuda.is_available() else 'cpu') for x in stop_token_ids]
+
+ # Streaming text-generation helper that calls the Inference API directly
+ # (kept for reference; the retrieval chain below goes through InferenceClientLLM)
+ def generate(prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0):
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=42,
+     )
+
+     formatted_prompt = format_prompt(prompt)  # format_prompt takes only the query
+     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+
+     for response in stream:
+         output += response.token.text
+         yield output
+     return output
+
+ # HuggingFacePipeline expects a transformers pipeline object, not a plain Python
+ # function, so the InferenceClient is exposed to LangChain through a minimal
+ # custom LLM wrapper instead.
+ class InferenceClientLLM(LLM):
+     @property
+     def _llm_type(self) -> str:
+         return "hf-inference-client"
+
+     def _call(self, prompt, stop=None, run_manager=None, **kwargs):
+         return client.text_generation(
+             prompt,
+             max_new_tokens=1024,
+             temperature=0.2,
+             top_p=0.95,
+             do_sample=True,
+             stop_sequences=stop_list,
+         )
+
+ llm = InferenceClientLLM()
+
+ # Load the stored FAISS index with the same embedding model used to build it
+ try:
+     embeddings = HuggingFaceEmbeddings(
+         model_name="sentence-transformers/all-mpnet-base-v2",
+         model_kwargs={"device": "cuda" if cuda.is_available() else "cpu"},
+     )
+     vectorstore = FAISS.load_local('faiss_index', embeddings)
+     print("Loaded embedding successfully")
+ except ImportError as e:
+     print("FAISS could not be imported. Make sure FAISS is installed correctly.")
+     raise e
+
+ # Set up the Conversational Retrieval Chain
+ chain = ConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(), return_source_documents=True)
+
+ chat_history = []
+
+ def format_prompt(query):
+     prompt = f"""
+     You are a knowledgeable assistant with access to a comprehensive database.
+     I need you to answer my question and provide related information in a specific format.
+     Here's what I need:
+     1. A brief, general response to my question based on related answers retrieved.
+     2. A JSON-formatted output containing:
+        - "question": The original question.
+        - "answer": The detailed answer.
+        - "related_questions": A list of related questions and their answers, each as a dictionary with the keys:
+          - "question": The related question.
+          - "answer": The related answer.
+     Here's my question:
+     {query}
+     Include a brief final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
+     """
+     return prompt
+
+ def qa_infer(query):
+     formatted_prompt = format_prompt(query)
+     result = chain({"question": formatted_prompt, "chat_history": chat_history})
+     # Log the retrieved source documents for debugging
+     for doc in result['source_documents']:
+         print("-"*50)
+         print("Retrieved Document:", doc.page_content)
+     print("#"*100)
+     print(result['answer'])
+     return result['answer']
+
+ EXAMPLES = ["How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
+             "Can BQ25896 support I2C interface?",
+             "Does TDA2 vout support bt656 8-bit mode?"]
+
+ demo = gr.Interface(fn=qa_infer, inputs="text", allow_flagging='never', examples=EXAMPLES, cache_examples=False, outputs="text")
+ demo.launch()
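
Not part of this commit, but worth noting: a Space running this app.py also needs a requirements.txt alongside it, and HF_TOKEN set as a Space secret. A minimal sketch inferred from the imports above (package names only, unpinned; faiss-cpu is an assumption and may be faiss-gpu on GPU hardware):

    torch
    transformers
    huggingface_hub
    langchain
    faiss-cpu
    sentence-transformers
    gradio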