Daniel Marques committed
Commit: cb776ef
Parent: e845546

feat: add history

Files changed (2):
  1. main.py +9 -7
  2. run_localGPT.py +1 -1
main.py CHANGED
@@ -12,6 +12,7 @@ import subprocess
 from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.prompts import PromptTemplate
+from langchain.memory import ConversationBufferMemory
 
 # from langchain.embeddings import HuggingFaceEmbeddings
 from run_localGPT import load_model
@@ -55,11 +56,13 @@ Always answer in the most helpful and safe way possible.
 If you don't know the answer to a question, just say that you don't know, don't try to make up an answer, don't share false information.
 Use 15 sentences maximum. Keep the answer as concise as possible.
 Always say "thanks for asking!" at the end of the answer.
-{context}
+Context: {history} \n {context}
 Question: {question}
-Helpful Answer:"""
+"""
 
-QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
+memory = ConversationBufferMemory(input_key="question", memory_key="history")
+
+QA_CHAIN_PROMPT = PromptTemplate.from_template(input_variables=["history", "context", "question"], template=template)
 
 QA = RetrievalQA.from_chain_type(
     llm=LLM,
@@ -68,6 +71,7 @@ QA = RetrievalQA.from_chain_type(
     return_source_documents=SHOW_SOURCES,
     chain_type_kwargs={
         "prompt": QA_CHAIN_PROMPT,
+        "memory": memory
     },
 )
 
@@ -118,7 +122,6 @@ def run_ingest_route():
         )
 
         RETRIEVER = DB.as_retriever()
-        prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
 
         QA = RetrievalQA.from_chain_type(
             llm=LLM,
@@ -127,12 +130,11 @@
             return_source_documents=SHOW_SOURCES,
             chain_type_kwargs={
                 "prompt": QA_CHAIN_PROMPT,
+                "memory": memory
             },
         )
 
-        response = "Script executed successfully: {}".format(result.stdout.decode("utf-8"))
-
-        return {"response": response}
+        return {"response": "The training was successfully completed"}
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
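
Taken together, the main.py changes thread a ConversationBufferMemory through the RetrievalQA chain so each request sees the running question/answer history via the new {history} placeholder. The sketch below shows the same wiring in isolation, under a few stated assumptions: LLM and RETRIEVER stand for the app's existing globals (built by load_model and DB.as_retriever() elsewhere in the file), the template is abbreviated, and chain_type="stuff" follows the usual localGPT setup. One note: PromptTemplate.from_template already infers input_variables from the template string, so the committed call passes it redundantly; the plain constructor used here is the more conventional spelling.

# Sketch only: LLM and RETRIEVER are assumed to be the app's existing globals.
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# Abbreviated stand-in for the system prompt defined in main.py.
template = """Use the following context to answer the question at the end.
Context: {history} \n {context}
Question: {question}
"""

# input_key/memory_key must match the {question} and {history} placeholders.
memory = ConversationBufferMemory(input_key="question", memory_key="history")

QA_CHAIN_PROMPT = PromptTemplate(
    input_variables=["history", "context", "question"], template=template
)

QA = RetrievalQA.from_chain_type(
    llm=LLM,              # assumed: model returned by load_model(...)
    chain_type="stuff",   # assumed: localGPT's usual chain type
    retriever=RETRIEVER,  # assumed: DB.as_retriever()
    return_source_documents=True,
    chain_type_kwargs={"prompt": QA_CHAIN_PROMPT, "memory": memory},
)

Because the memory object lives in the module-level chain, history is shared across all clients of the API; per-session history would need one ConversationBufferMemory per client.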
 
run_localGPT.py CHANGED
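
The diff below removes skip_prompt=True from the TextStreamer, so after this commit the prompt itself is echoed to stdout along with the newly generated tokens. For reference, here is a minimal self-contained sketch of how a streamer attaches to a transformers text-generation pipeline; "gpt2" is purely a stand-in for whatever model load_model actually resolves.

# "gpt2" is a placeholder; load_model() loads the configured model_id/basename.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, pipeline

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# As of this commit the prompt is streamed too; pass skip_prompt=True to
# restore the old behavior of printing only newly generated tokens.
streamer = TextStreamer(tokenizer)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=32,
    streamer=streamer,  # decoded tokens are printed to stdout as they arrive
)

pipe("LocalGPT keeps inference on your own machine, ")
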
@@ -79,7 +79,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
 
     # Create a pipeline for text generation
 
-    streamer = TextStreamer(tokenizer, skip_prompt=True)
+    streamer = TextStreamer(tokenizer)
 
     pipe = pipeline(
         "text-generation",