Hugging Face Spaces — commit "Update app.py" (Space status: Runtime error). Changed file: app.py.
@@ -50,7 +50,22 @@ def transcribe_video(youtube_url: str, path: str) -> List[Document]:
|
|
50 |
result = client.predict(youtube_url, "translate", True, fn_index=7)
|
51 |
return [Document(page_content=result[1], metadata=dict(page=1))]
|
52 |
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
PATH = os.path.join(os.path.expanduser("~"), "Data")
|
56 |
|
@@ -100,6 +115,26 @@ prompt = PromptTemplate(
|
|
100 |
input_variables=["context", "question"]
|
101 |
)
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
# Check if a new YouTube URL is provided
|
104 |
if st.session_state.youtube_url != st.session_state.doneYoutubeurl:
|
105 |
st.session_state.setup_done = False
|
@@ -116,11 +151,7 @@ if st.session_state.youtube_url and not st.session_state.setup_done :
|
|
116 |
retriever.search_kwargs['distance_metric'] = 'cos'
|
117 |
retriever.search_kwargs['k'] = 4
|
118 |
with st.status("Running RetrievalQA..."):
|
119 |
-
llama_instance =
|
120 |
-
model_kwargs={"max_length": 4096},
|
121 |
-
repo_id="meta-llama/Llama-2-70b-chat-hf",
|
122 |
-
huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
|
123 |
-
)
|
124 |
st.session_state.qa = RetrievalQA.from_chain_type(llm=llama_instance, chain_type="stuff", retriever=retriever,chain_type_kwargs={"prompt": prompt})
|
125 |
|
126 |
st.session_state.doneYoutubeurl = st.session_state.youtube_url
|
|
|
50 |
result = client.predict(youtube_url, "translate", True, fn_index=7)
|
51 |
return [Document(page_content=result[1], metadata=dict(page=1))]
|
52 |
|
53 |
+
def predict(message: str, system_prompt: str = '', temperature: float = 0.7, max_new_tokens: int = 4096,
            topp: float = 0.5, repetition_penalty: float = 1.2) -> Any:
    """Send *message* to the hosted Llama-2 Gradio Space and return its reply.

    Args:
        message: User prompt forwarded to the remote chat endpoint.
        system_prompt: Optional system instruction; empty by default.
        temperature: Sampling temperature passed through unchanged.
        max_new_tokens: Generation length cap passed through unchanged.
        topp: Nucleus-sampling parameter passed through unchanged.
        repetition_penalty: Repetition penalty passed through unchanged.

    Returns:
        Whatever the Space's ``/chat_1`` endpoint produces (presumably the
        generated text — not verifiable from here).
    """
    # NOTE(review): a fresh Client is built on every call, and the URL pins
    # a specific replica ("--replicas/xwjz8") that may go stale — confirm.
    space = Client("https://ysharma-explore-llamav2-with-tgi.hf.space/--replicas/xwjz8/")
    return space.predict(
        message,
        system_prompt,
        temperature,
        max_new_tokens,
        topp,
        repetition_penalty,
        api_name="/chat_1",
    )
|
69 |
|
70 |
PATH = os.path.join(os.path.expanduser("~"), "Data")
|
71 |
|
|
|
115 |
input_variables=["context", "question"]
|
116 |
)
|
117 |
|
118 |
+
class LlamaLLM(LLM):
    """LangChain LLM wrapper that delegates text generation to ``predict``.

    Implements the minimal custom-LLM surface expected by LangChain:
    ``_llm_type``, ``_call`` and ``_identifying_params``.
    """

    @property
    def _llm_type(self) -> str:
        # Label reported to LangChain callbacks / serialization.
        return "custom"

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
        # NOTE(review): `stop` and `run_manager` are accepted only for
        # interface compatibility; they are not forwarded to the endpoint.
        return predict(prompt)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the identifying parameters (none for this wrapper)."""
        # No per-instance configuration distinguishes one wrapper from another.
        return {}
|
136 |
+
|
137 |
+
|
138 |
# Check if a new YouTube URL is provided
|
139 |
if st.session_state.youtube_url != st.session_state.doneYoutubeurl:
|
140 |
st.session_state.setup_done = False
|
|
|
151 |
retriever.search_kwargs['distance_metric'] = 'cos'
|
152 |
retriever.search_kwargs['k'] = 4
|
153 |
with st.status("Running RetrievalQA..."):
|
154 |
+
llama_instance = LlamaLLM()
|
|
|
|
|
|
|
|
|
155 |
st.session_state.qa = RetrievalQA.from_chain_type(llm=llama_instance, chain_type="stuff", retriever=retriever,chain_type_kwargs={"prompt": prompt})
|
156 |
|
157 |
st.session_state.doneYoutubeurl = st.session_state.youtube_url
|