Illia56 committed on
Commit
152ba24
1 Parent(s): f0b62b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -6
app.py CHANGED
@@ -50,7 +50,22 @@ def transcribe_video(youtube_url: str, path: str) -> List[Document]:
50
  result = client.predict(youtube_url, "translate", True, fn_index=7)
51
  return [Document(page_content=result[1], metadata=dict(page=1))]
52
 
53
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  PATH = os.path.join(os.path.expanduser("~"), "Data")
56
 
@@ -100,6 +115,26 @@ prompt = PromptTemplate(
100
  input_variables=["context", "question"]
101
  )
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  # Check if a new YouTube URL is provided
104
  if st.session_state.youtube_url != st.session_state.doneYoutubeurl:
105
  st.session_state.setup_done = False
@@ -116,11 +151,7 @@ if st.session_state.youtube_url and not st.session_state.setup_done :
116
  retriever.search_kwargs['distance_metric'] = 'cos'
117
  retriever.search_kwargs['k'] = 4
118
  with st.status("Running RetrievalQA..."):
119
- llama_instance = HuggingFaceHub(
120
- model_kwargs={"max_length": 4096},
121
- repo_id="meta-llama/Llama-2-70b-chat-hf",
122
- huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
123
- )
124
  st.session_state.qa = RetrievalQA.from_chain_type(llm=llama_instance, chain_type="stuff", retriever=retriever,chain_type_kwargs={"prompt": prompt})
125
 
126
  st.session_state.doneYoutubeurl = st.session_state.youtube_url
 
50
  result = client.predict(youtube_url, "translate", True, fn_index=7)
51
  return [Document(page_content=result[1], metadata=dict(page=1))]
52
 
53
def predict(message: str, system_prompt: str = '', temperature: float = 0.7, max_new_tokens: int = 4096,
            topp: float = 0.5, repetition_penalty: float = 1.2) -> Any:
    """Send one chat turn to the hosted Llama-2 Gradio Space and return its reply.

    Args:
        message: User prompt forwarded to the model.
        system_prompt: Optional system instruction for the chat.
        temperature: Sampling temperature passed through to the endpoint.
        max_new_tokens: Generation length cap passed through to the endpoint.
        topp: Nucleus-sampling (top-p) value passed through to the endpoint.
        repetition_penalty: Repetition penalty passed through to the endpoint.

    Returns:
        Whatever the Space's ``/chat_1`` endpoint returns.
    """
    # NOTE(review): the replica URL is hard-coded and a fresh Client is built
    # on every call — both may be worth revisiting, but kept as-is here.
    space_url = "https://ysharma-explore-llamav2-with-tgi.hf.space/--replicas/xwjz8/"
    remote = Client(space_url)
    # The endpoint takes its arguments positionally, in exactly this order.
    return remote.predict(
        message,
        system_prompt,
        temperature,
        max_new_tokens,
        topp,
        repetition_penalty,
        api_name="/chat_1",
    )
69
 
70
  PATH = os.path.join(os.path.expanduser("~"), "Data")
71
 
 
115
  input_variables=["context", "question"]
116
  )
117
 
118
class LlamaLLM(LLM):
    """LangChain LLM adapter that delegates text generation to the remote
    Llama-2 Space via the module-level ``predict`` helper."""

    @property
    def _llm_type(self) -> str:
        # Identifier LangChain uses for callbacks/serialization.
        return "custom"

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
        # NOTE(review): `stop` and `run_manager` are accepted only to satisfy
        # the LLM interface — the remote endpoint receives neither. Confirm
        # stop-sequence handling is not required by callers.
        return predict(prompt)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {}
137
+
138
  # Check if a new YouTube URL is provided
139
  if st.session_state.youtube_url != st.session_state.doneYoutubeurl:
140
  st.session_state.setup_done = False
 
151
  retriever.search_kwargs['distance_metric'] = 'cos'
152
  retriever.search_kwargs['k'] = 4
153
  with st.status("Running RetrievalQA..."):
154
+ llama_instance = LlamaLLM()
 
 
 
 
155
  st.session_state.qa = RetrievalQA.from_chain_type(llm=llama_instance, chain_type="stuff", retriever=retriever,chain_type_kwargs={"prompt": prompt})
156
 
157
  st.session_state.doneYoutubeurl = st.session_state.youtube_url