Update app.py
Browse files
app.py
CHANGED
|
@@ -26,6 +26,7 @@ from langchain_core.documents import Document
|
|
| 26 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 27 |
from sentence_transformers import SentenceTransformer
|
| 28 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
| 29 |
# from langchain.agents import create_tool_calling_agent
|
| 30 |
|
| 31 |
# (Keep Constants as is)
|
|
@@ -42,14 +43,13 @@ class Config(object):
|
|
| 42 |
self.random_state = 42
|
| 43 |
self.max_len = 256
|
| 44 |
self.reasoning_max_len = 256
|
| 45 |
-
self.temperature = 0.
|
| 46 |
self.repetition_penalty = 1.2
|
| 47 |
self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 48 |
self.model_name = "Qwen/Qwen2.5-7B-Instruct"
|
| 49 |
-
# self.
|
| 50 |
# self.reasoning_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
|
| 51 |
# self.reasoning_model_name = "Qwen/Qwen2.5-7B-Instruct"
|
| 52 |
-
# self.reasoning_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
|
| 53 |
|
| 54 |
|
| 55 |
config = Config()
|
|
@@ -59,14 +59,14 @@ tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
|
| 59 |
model = AutoModelForCausalLM.from_pretrained(
|
| 60 |
config.model_name,
|
| 61 |
torch_dtype=torch.float16,
|
| 62 |
-
device_map=
|
| 63 |
)
|
| 64 |
|
| 65 |
# reasoning_tokenizer = AutoTokenizer.from_pretrained(config.reasoning_model_name)
|
| 66 |
# reasoning_model = AutoModelForCausalLM.from_pretrained(
|
| 67 |
# config.reasoning_model_name,
|
| 68 |
# torch_dtype=torch.float16,
|
| 69 |
-
# device_map=
|
| 70 |
# )
|
| 71 |
|
| 72 |
def generate(prompt):
|
|
@@ -128,37 +128,6 @@ def reasoning_generate(prompt):
|
|
| 128 |
generated = outputs[0][inputs["input_ids"].shape[-1]:]
|
| 129 |
|
| 130 |
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def reasoning_generate(prompt):
|
| 134 |
-
"""
|
| 135 |
-
Generate a text completion from a causal language model given a prompt.
|
| 136 |
-
|
| 137 |
-
Parameters
|
| 138 |
-
----------
|
| 139 |
-
prompt : str
|
| 140 |
-
Input text prompt used to condition the language model.
|
| 141 |
-
|
| 142 |
-
Returns
|
| 143 |
-
-------
|
| 144 |
-
str
|
| 145 |
-
The generated continuation text, decoded into a string with special
|
| 146 |
-
tokens removed and leading/trailing whitespace stripped.
|
| 147 |
-
|
| 148 |
-
"""
|
| 149 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 150 |
-
|
| 151 |
-
with torch.no_grad():
|
| 152 |
-
outputs = model.generate(
|
| 153 |
-
**inputs,
|
| 154 |
-
max_new_tokens=config.reasoning_max_len,
|
| 155 |
-
temperature=config.temperature,
|
| 156 |
-
repetition_penalty = config.repetition_penalty,
|
| 157 |
-
)
|
| 158 |
-
|
| 159 |
-
generated = outputs[0][inputs["input_ids"].shape[-1]:]
|
| 160 |
-
|
| 161 |
-
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
| 162 |
|
| 163 |
|
| 164 |
class Action(BaseModel):
|
|
@@ -880,45 +849,52 @@ def tool_executor(state: AgentState):
|
|
| 880 |
|
| 881 |
elif action.tool == "visit_webpage":
|
| 882 |
try:
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
|
| 888 |
-
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
if query_webpage_information_similarity_score > 0.65:
|
| 894 |
-
webpage_information_complete += webpage_result
|
| 895 |
-
webpage_information_complete += " \n "
|
| 896 |
-
webpage_information_complete += " \n "
|
| 897 |
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 902 |
|
| 903 |
|
| 904 |
-
webpage_results = visit_webpage_main(result)
|
| 905 |
-
webpage_result = " \n ".join(webpage_results)
|
| 906 |
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
| 910 |
-
query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
|
| 911 |
-
|
| 912 |
-
# logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
|
| 913 |
-
|
| 914 |
-
if query_webpage_information_similarity_score > 0.65:
|
| 915 |
-
webpage_information_complete += webpage_result
|
| 916 |
-
webpage_information_complete += " \n "
|
| 917 |
-
webpage_information_complete += " \n "
|
| 918 |
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 922 |
except:
|
| 923 |
pass
|
| 924 |
else:
|
|
@@ -985,9 +961,10 @@ class BasicAgent:
|
|
| 985 |
# if question == "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?" or question == "What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?":
|
| 986 |
# if question == "The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.":
|
| 987 |
# if question == "Where were the Vietnamese specimens described by Kuznetzov in Nedoshivina's 2010 paper eventually deposited? Just give me the city name without abbreviations.":
|
|
|
|
| 988 |
|
| 989 |
-
if question != "aalskdalsdh" and filename == "":
|
| 990 |
-
time.sleep(
|
| 991 |
|
| 992 |
|
| 993 |
state = {
|
|
|
|
| 26 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 27 |
from sentence_transformers import SentenceTransformer
|
| 28 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 29 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
| 30 |
# from langchain.agents import create_tool_calling_agent
|
| 31 |
|
| 32 |
# (Keep Constants as is)
|
|
|
|
| 43 |
self.random_state = 42
|
| 44 |
self.max_len = 256
|
| 45 |
self.reasoning_max_len = 256
|
| 46 |
+
self.temperature = 0.01
|
| 47 |
self.repetition_penalty = 1.2
|
| 48 |
self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 49 |
self.model_name = "Qwen/Qwen2.5-7B-Instruct"
|
| 50 |
+
# self.reasoning_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
|
| 51 |
# self.reasoning_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
|
| 52 |
# self.reasoning_model_name = "Qwen/Qwen2.5-7B-Instruct"
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
config = Config()
|
|
|
|
| 59 |
model = AutoModelForCausalLM.from_pretrained(
|
| 60 |
config.model_name,
|
| 61 |
torch_dtype=torch.float16,
|
| 62 |
+
device_map=config.DEVICE
|
| 63 |
)
|
| 64 |
|
| 65 |
# reasoning_tokenizer = AutoTokenizer.from_pretrained(config.reasoning_model_name)
|
| 66 |
# reasoning_model = AutoModelForCausalLM.from_pretrained(
|
| 67 |
# config.reasoning_model_name,
|
| 68 |
# torch_dtype=torch.float16,
|
| 69 |
+
# device_map=config.DEVICE
|
| 70 |
# )
|
| 71 |
|
| 72 |
def generate(prompt):
|
|
|
|
| 128 |
generated = outputs[0][inputs["input_ids"].shape[-1]:]
|
| 129 |
|
| 130 |
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
|
| 133 |
class Action(BaseModel):
|
|
|
|
| 849 |
|
| 850 |
elif action.tool == "visit_webpage":
|
| 851 |
try:
|
| 852 |
+
if "www.youtube.com" in str(action.args["url"]):
|
| 853 |
+
video_id = action.args["url"].split("www.youtube.com/watch?v=")[-1]
|
| 854 |
+
api = YouTubeTranscriptApi()
|
| 855 |
+
transcript = api.fetch(video_id)
|
| 856 |
+
texts = [x.text for x in transcript]
|
| 857 |
+
webpage_information_complete = " \n ".join([x.text for x in transcript])
|
| 858 |
+
else:
|
| 859 |
+
webpage_results = visit_webpage_wiki(result)
|
| 860 |
+
webpage_result = " \n ".join(webpage_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 861 |
|
| 862 |
+
# for webpage_result in webpage_results:
|
| 863 |
+
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 864 |
+
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
| 865 |
+
query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
|
| 866 |
+
|
| 867 |
+
# logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
|
| 868 |
+
|
| 869 |
+
if query_webpage_information_similarity_score > 0.65:
|
| 870 |
+
webpage_information_complete += webpage_result
|
| 871 |
+
webpage_information_complete += " \n "
|
| 872 |
+
webpage_information_complete += " \n "
|
| 873 |
+
|
| 874 |
+
if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
|
| 875 |
+
best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
|
| 876 |
+
best_webpage_information = webpage_result
|
| 877 |
|
| 878 |
|
|
|
|
|
|
|
| 879 |
|
| 880 |
+
webpage_results = visit_webpage_main(result)
|
| 881 |
+
webpage_result = " \n ".join(webpage_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 882 |
|
| 883 |
+
# for webpage_result in webpage_results:
|
| 884 |
+
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 885 |
+
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
| 886 |
+
query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
|
| 887 |
+
|
| 888 |
+
# logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
|
| 889 |
+
|
| 890 |
+
if query_webpage_information_similarity_score > 0.65:
|
| 891 |
+
webpage_information_complete += webpage_result
|
| 892 |
+
webpage_information_complete += " \n "
|
| 893 |
+
webpage_information_complete += " \n "
|
| 894 |
+
|
| 895 |
+
if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
|
| 896 |
+
best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
|
| 897 |
+
best_webpage_information = webpage_result
|
| 898 |
except:
|
| 899 |
pass
|
| 900 |
else:
|
|
|
|
| 961 |
# if question == "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?" or question == "What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?":
|
| 962 |
# if question == "The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.":
|
| 963 |
# if question == "Where were the Vietnamese specimens described by Kuznetzov in Nedoshivina's 2010 paper eventually deposited? Just give me the city name without abbreviations.":
|
| 964 |
+
if question == "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal\'c say in response to the question 'Isn\'t that hot?'":
|
| 965 |
|
| 966 |
+
# if question != "aalskdalsdh" and filename == "":
|
| 967 |
+
time.sleep(120)
|
| 968 |
|
| 969 |
|
| 970 |
state = {
|