Chris Alexiuk committed on
Commit
3391cce
1 Parent(s): 643f5c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -47
app.py CHANGED
@@ -56,55 +56,12 @@ async def init():
56
  # docsearch = await cl.make_async(Chroma.from_documents)(pdf_data, embeddings)
57
  docsearch = Chroma.from_documents(pdf_data, embeddings)
58
 
59
- # custom SageMaker Model
60
- class Llama2SageMaker(LLM):
61
- max_new_tokens: int = 256
62
- top_p: float = 0.9
63
- temperature: float = 0.1
64
-
65
- @property
66
- def _llm_type(self) -> str:
67
- return "Llama2SageMaker"
68
-
69
- def _call(
70
- self,
71
- prompt: str,
72
- stop: Optional[List[str]] = None,
73
- run_manager: Optional[CallbackManagerForLLMRun] = None,
74
- ) -> str:
75
- if stop is not None:
76
- raise ValueError("stop kwargs are not permitted.")
77
-
78
- json_body = {
79
- "inputs" : [
80
- [{"role" : "user", "content" : prompt}]
81
- ],
82
- "parameters" : {
83
- "max_new_tokens" : self.max_new_tokens,
84
- "top_p" : self.top_p,
85
- "temperature" : self.temperature
86
- }
87
- }
88
-
89
- response = requests.post(model_api_gateway, json=json_body)
90
-
91
- return response.json()[0]["generation"]["content"]
92
-
93
- @property
94
- def _identifying_params(self) -> Mapping[str, Any]:
95
- """Get the identifying parameters."""
96
- return {
97
- "max_new_tokens" : self.max_new_tokens,
98
- "top_p" : self.top_p,
99
- "temperature" : self.temperature
100
- }
101
-
102
- # set our llm to the custom Llama2SageMaker endpoint model
103
- llm = Llama2SageMaker()
104
-
105
  # Create a chain that uses the Chroma vector store
106
  chain = RetrievalQAWithSourcesChain.from_chain_type(
107
- llm=llm,
 
 
 
108
  chain_type="stuff",
109
  retriever=docsearch.as_retriever(),
110
  return_source_documents=True,
 
56
  # docsearch = await cl.make_async(Chroma.from_documents)(pdf_data, embeddings)
57
  docsearch = Chroma.from_documents(pdf_data, embeddings)
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  # Create a chain that uses the Chroma vector store
60
  chain = RetrievalQAWithSourcesChain.from_chain_type(
61
+ ChatOpenAI(
62
+ model="gpt-3.5-turbo",
63
+ temperature=0.0
64
+ ),
65
  chain_type="stuff",
66
  retriever=docsearch.as_retriever(),
67
  return_source_documents=True,