UldisKK committed on
Commit
81e3d37
1 Parent(s): 874cdc5

add the rest of the program

Files changed (1)
  1. app.py +60 -0
app.py CHANGED
@@ -37,4 +37,64 @@ text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=51
  texts = text_splitter.split_documents(pdf_pages)
  st.write('total chunks from pages:', len(texts))
 
+ st.write('loading chunks into vector db')
+ model_name = "hkunlp/instructor-large"
+ hf_embeddings = HuggingFaceInstructEmbeddings(
+     model_name = model_name)
+ db = Chroma.from_documents(texts, hf_embeddings)
 
+ st.write('loading LLM')
+ model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
+
+ model_basename = "model"
+ use_triton = False
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
+     model_basename=model_basename,
+     use_safetensors=True,
+     trust_remote_code=True,
+     device=DEVICE,
+     use_triton=use_triton,
+     quantize_config=None)
+
+ st.write('setting up the chain')
+ streamer = TextStreamer(tokenizer, skip_prompt = True, skip_special_tokens = True)
+ text_pipeline = pipeline(task = 'text-generation', model = model, tokenizer = tokenizer, streamer = streamer)
+ llm = HuggingFacePipeline(pipeline = text_pipeline)
+
+ def generate_prompt(prompt, sys_prompt):
+     return f"[INST] <<SYS>> {sys_prompt} <</SYS>> {prompt} [/INST]"
+
+ sys_prompt = "Use following piece of context to answer the question in less than 20 words"
+ template = generate_prompt(
+     """
+     {context}
+
+     Question : {question}
+     """
+     , sys_prompt)
+
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+
+ chain_type_kwargs = {"prompt": prompt}
+ qa_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     chain_type="stuff",
+     retriever=db.as_retriever(search_kwargs={"k": 2}),
+     return_source_documents = True,
+     chain_type_kwargs=chain_type_kwargs,
+ )
+ st.write('READY!!!')
+
+ q1="what the author worked on ?"
+ q2="where did author study?"
+ q3="what author did ?"
+ result = qa_chain(q1)
+ st.write('question:', q1, 'result:', result)
+
+ result = qa_chain(q2)
+ st.write('question:', q2, 'result:', result)
+
+ result = qa_chain(q3)
+ st.write('question:', q3, 'result:', result)
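
For readers reproducing this app: a minimal sketch of the imports the hunk above appears to assume near the top of app.py, which the diff (starting at line 37) does not show. The module paths follow the 2023-era langchain / transformers / auto-gptq packages; they, and the PyPDFLoader line as the presumed source of pdf_pages, are assumptions rather than part of this commit.

# Assumed import block (not shown in this hunk); paths follow 2023-era package layouts.
import torch
import streamlit as st
from transformers import AutoTokenizer, TextStreamer, pipeline
from auto_gptq import AutoGPTQForCausalLM
from langchain.document_loaders import PyPDFLoader  # hypothetical: presumed source of pdf_pages
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA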