Illia56 committed on
Commit b37a5cd
1 Parent(s): 5c583d1

Upload app.py

Files changed (1):
  app.py +140 -72
app.py CHANGED
@@ -1,82 +1,150 @@
- import gradio as gr

  from gradio_client import Client

- title = "Llama2 70B Chatbot"
- description = """
- This Space demonstrates model [Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) by Meta, a Llama 2 model with 70B parameters fine-tuned for chat instructions.
- | Model | Llama2 | Llama2-hf | Llama2-chat | Llama2-chat-hf |
  |---|---|---|---|---|
  | 70B | [Link](https://huggingface.co/meta-llama/Llama-2-70b) | [Link](https://huggingface.co/meta-llama/Llama-2-70b-hf) | [Link](https://huggingface.co/meta-llama/Llama-2-70b-chat) | [Link](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) |
  """
- css = """.toast-wrap { display: none !important } """
- examples = [
-     ['Hello there! How are you doing?'],
-     ['Can you explain to me briefly what is Python programming language?'],
-     ['Explain the plot of Cinderella in a sentence.'],
-     ['How many hours does it take a man to eat a Helicopter?'],
-     ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
- ]
-
-
- # Stream text
- def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.0):
      client = Client("https://ysharma-explore-llamav2-with-tgi.hf.space/")
-     return client.predict(
-         message,         # str in 'Message' Textbox component
-         system_prompt,   # str in 'Optional system prompt' Textbox component
-         temperature,     # int | float (numeric value between 0.0 and 1.0)
-         max_new_tokens,  # int | float (numeric value between 0 and 4096)
-         0.3,             # int | float (numeric value between 0.0 and 1)
-         1,               # int | float (numeric value between 1.0 and 2.0)
-         api_name="/chat_1"
      )
-
- additional_inputs = [
-     gr.Textbox("", label="Optional system prompt"),
-     gr.Slider(
-         label="Temperature",
-         value=0.9,
-         minimum=0.0,
-         maximum=1.0,
-         step=0.05,
-         interactive=True,
-         info="Higher values produce more diverse outputs",
-     ),
-     gr.Slider(
-         label="Max new tokens",
-         value=4096,
-         minimum=0,
-         maximum=4096,
-         step=64,
-         interactive=True,
-         info="The maximum number of new tokens",
-     ),
-     gr.Slider(
-         label="Top-p (nucleus sampling)",
-         value=0.6,
-         minimum=0.0,
-         maximum=1,
-         step=0.05,
-         interactive=True,
-         info="Higher values sample more low-probability tokens",
-     ),
-     gr.Slider(
-         label="Repetition penalty",
-         value=1.2,
-         minimum=1.0,
-         maximum=2.0,
-         step=0.05,
-         interactive=True,
-         info="Penalize repeated tokens",
-     )
- ]


- # Gradio Demo
- with gr.Blocks(theme=gr.themes.Base()) as demo:
-     gr.DuplicateButton()
-     gr.ChatInterface(predict, title=title, additional_inputs=additional_inputs, description=description, css=css, examples=examples)
-
- demo.queue().launch(debug=True)
+ import os
+ import logging
+ from typing import Any, List, Mapping, Optional
+
  from gradio_client import Client
+ from langchain.schema import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import FAISS
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
+ from langchain.llms.base import LLM
+ from langchain.chains import RetrievalQA
+ import streamlit as st
+
+ # Markdown table of the Llama-2 70B checkpoints (currently unused in the app)
+ models = '''| Model | Llama2 | Llama2-hf | Llama2-chat | Llama2-chat-hf |
  |---|---|---|---|---|
  | 70B | [Link](https://huggingface.co/meta-llama/Llama-2-70b) | [Link](https://huggingface.co/meta-llama/Llama-2-70b-hf) | [Link](https://huggingface.co/meta-llama/Llama-2-70b-chat) | [Link](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) |
+ ---'''
+
+
+ DESCRIPTION = """
+ Welcome to the **YouTube Video Chatbot** powered by the state-of-the-art Llama-2-70b model. Here's what you can do:
+
+ - **Transcribe & Understand**: Provide any YouTube video URL, and our system will transcribe it. Our advanced NLP model will then understand the content, ready to answer your questions.
+ - **Ask Anything**: Based on the video's content, ask any question and get instant, context-aware answers.
+ - **Deep Dive**: Our model doesn't just provide generic answers; it understands the context, nuances, and details from the video.
+ - **Safe & Private**: We value your privacy. The videos you provide are only used for transcription and are not stored or used for any other purpose.
+
+ To get started, simply paste a YouTube video URL in the sidebar and start chatting with the model about the video's content. Enjoy the experience!
  """
+
+
+ def transcribe_video(youtube_url: str, path: str) -> List[Document]:
+     """
+     Transcribe a YouTube video and return its transcript as a single Document.
+     """
+     logging.info(f"Transcribing video: {youtube_url}")
+     client = Client("https://sanchit-gandhi-whisper-jax.hf.space/")
+     result = client.predict(youtube_url, "translate", True, fn_index=7)
+     return [Document(page_content=result[1], metadata=dict(page=1))]
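The call above assumes the whisper-jax Space's endpoint at `fn_index=7` accepts `(youtube_url, task, return_timestamps)` and that the transcript text is the second element of the returned tuple; note the `path` argument is accepted but unused here. A minimal standalone smoke test (a sketch; the URL is a placeholder and the public Space must be reachable):

```python
if __name__ == "__main__":
    docs = transcribe_video("https://www.youtube.com/watch?v=<VIDEO_ID>", "/tmp")
    print(docs[0].page_content[:200])  # first 200 characters of the transcript
```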
+
+
+ def predict(message: str, system_prompt: str = '', temperature: float = 0.7, max_new_tokens: int = 4096,
+             topp: float = 0.5, repetition_penalty: float = 1.2) -> Any:
+     """
+     Send a prompt to the hosted Llama-2-70b Space and return its response.
+     """
      client = Client("https://ysharma-explore-llamav2-with-tgi.hf.space/")
+     response = client.predict(
+         message,
+         system_prompt,
+         temperature,
+         max_new_tokens,
+         topp,
+         repetition_penalty,
+         api_name="/chat_1"
      )
+     return response
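The remote endpoint can be smoke-tested on its own, outside Streamlit (a sketch; it assumes the ysharma Space is up and still exposes the `/chat_1` route):

```python
reply = predict(
    "Explain the plot of Cinderella in a sentence.",
    system_prompt="You are a concise assistant.",
    max_new_tokens=256,
)
print(reply)
```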
+
+
+ class LlamaLLM(LLM):
+     """
+     Custom LangChain LLM that delegates generation to the remote Llama-2 Space.
+     """
+
+     @property
+     def _llm_type(self) -> str:
+         return "custom"
+
+     def _call(self, prompt: str, stop: Optional[List[str]] = None,
+               run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
+         response = predict(prompt)
+         return response
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         """Get the identifying parameters."""
+         return {}
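Because `LlamaLLM` implements LangChain's `LLM` interface (`_call` plus the `_llm_type` and `_identifying_params` properties), it can be passed anywhere the framework expects an LLM, such as the `RetrievalQA` chain below. In the LangChain generation these imports come from, an instance is also directly callable (a sketch):

```python
llm = LlamaLLM()
print(llm._llm_type)           # "custom"
print(llm("What is Python?"))  # routes through _call(), which wraps predict()
```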
+
+ PATH = os.path.join(os.path.expanduser("~"), "Data")
+
+
+ def initialize_session_state():
+     if "youtube_url" not in st.session_state:
+         st.session_state.youtube_url = ""
+     if "setup_done" not in st.session_state:  # Initialize the setup_done flag
+         st.session_state.setup_done = False
+     if "doneYoutubeurl" not in st.session_state:
+         st.session_state.doneYoutubeurl = ""
+
+
+ def sidebar():
+     with st.sidebar:
+         st.markdown(
+             "## How to use\n"
+             "1. Enter the YouTube Video URL below🔗\n"
+         )
+         st.session_state.youtube_url = st.text_input("YouTube Video URL:")
+
+
+ # st.set_page_config must be the first Streamlit command in the script,
+ # so the DESCRIPTION banner is rendered after it rather than at import time.
+ st.set_page_config(page_title="YouTube Video Chatbot",
+                    layout="centered",
+                    initial_sidebar_state="expanded")
+
+ st.title("YouTube Video Chatbot")
+ st.markdown(DESCRIPTION)
+ sidebar()
+ initialize_session_state()
+
+ # Check if a new YouTube URL is provided
+ if st.session_state.youtube_url != st.session_state.doneYoutubeurl:
+     st.session_state.setup_done = False
+
+ if st.session_state.youtube_url and not st.session_state.setup_done:
+     with st.status("Transcribing video..."):
+         data = transcribe_video(st.session_state.youtube_url, PATH)
+
+     with st.status("Running Embeddings..."):
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+         docs = text_splitter.split_documents(data)
+
+         embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-l6-v2")
+         docsearch = FAISS.from_documents(docs, embeddings)
+         retriever = docsearch.as_retriever()
+         retriever.search_kwargs['distance_metric'] = 'cos'
+         retriever.search_kwargs['k'] = 4
+
+     with st.status("Running RetrievalQA..."):
+         llama_instance = LlamaLLM()
+         st.session_state.qa = RetrievalQA.from_chain_type(llm=llama_instance, chain_type="stuff", retriever=retriever)
+
+     st.session_state.doneYoutubeurl = st.session_state.youtube_url
+     st.session_state.setup_done = True  # Mark the setup as done for this URL
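With `chain_type="stuff"`, the chain concatenates the retrieved chunks into a single prompt for the LLM; at `k=4` chunks of roughly 1,000 characters each, the stuffed context stays small. (The `distance_metric` entry appears to be carried over from other vector stores and to be ignored by the FAISS retriever; only `k` takes effect.) The retrieval step can be inspected on its own (a sketch, runnable once `retriever` has been built):

```python
# Peek at the transcript chunks the chain would stuff into the prompt
hits = retriever.get_relevant_documents("What is the main topic of the video?")
for i, doc in enumerate(hits, 1):
    print(f"chunk {i}: {doc.page_content[:80]}...")
```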
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ # Replay the chat history on each rerun
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"], avatar=("🧑‍💻" if message["role"] == 'human' else '🦙')):
+         st.markdown(message["content"])
+
+ textinput = st.chat_input("Ask Llama-2-70b anything about the video...")
+
+ if prompt := textinput:
+     st.chat_message("human", avatar="🧑‍💻").markdown(prompt)
+     st.session_state.messages.append({"role": "human", "content": prompt})
+     with st.status("Requesting Client..."):
+         response = st.session_state.qa.run(prompt)
+     with st.chat_message("assistant", avatar='🦙'):
+         st.markdown(response)
+     # Add assistant response to chat history
+     st.session_state.messages.append({"role": "assistant", "content": response})
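For local testing, the app runs with `streamlit run app.py` after installing `streamlit`, `gradio_client`, `langchain`, `faiss-cpu`, and `sentence-transformers` (exact version pins are an assumption here; the Space's environment is authoritative). Note that `st.session_state.qa` only exists after a URL has been processed, so a question submitted before transcription completes would fail.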