Sandiago21 committed on
Commit
b8579ec
·
verified ·
1 Parent(s): d66c78b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -72
app.py CHANGED
@@ -26,6 +26,7 @@ from langchain_core.documents import Document
26
  from langgraph.prebuilt import ToolNode, tools_condition
27
  from sentence_transformers import SentenceTransformer
28
  from sklearn.metrics.pairwise import cosine_similarity
 
29
  # from langchain.agents import create_tool_calling_agent
30
 
31
  # (Keep Constants as is)
@@ -42,14 +43,13 @@ class Config(object):
42
  self.random_state = 42
43
  self.max_len = 256
44
  self.reasoning_max_len = 256
45
- self.temperature = 0.1
46
  self.repetition_penalty = 1.2
47
  self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
48
  self.model_name = "Qwen/Qwen2.5-7B-Instruct"
49
- # self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
50
  # self.reasoning_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
51
  # self.reasoning_model_name = "Qwen/Qwen2.5-7B-Instruct"
52
- # self.reasoning_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
53
 
54
 
55
  config = Config()
@@ -59,14 +59,14 @@ tokenizer = AutoTokenizer.from_pretrained(config.model_name)
59
  model = AutoModelForCausalLM.from_pretrained(
60
  config.model_name,
61
  torch_dtype=torch.float16,
62
- device_map="auto"
63
  )
64
 
65
  # reasoning_tokenizer = AutoTokenizer.from_pretrained(config.reasoning_model_name)
66
  # reasoning_model = AutoModelForCausalLM.from_pretrained(
67
  # config.reasoning_model_name,
68
  # torch_dtype=torch.float16,
69
- # device_map="auto"
70
  # )
71
 
72
  def generate(prompt):
@@ -128,37 +128,6 @@ def reasoning_generate(prompt):
128
  generated = outputs[0][inputs["input_ids"].shape[-1]:]
129
 
130
  return tokenizer.decode(generated, skip_special_tokens=True).strip()
131
-
132
-
133
- def reasoning_generate(prompt):
134
- """
135
- Generate a text completion from a causal language model given a prompt.
136
-
137
- Parameters
138
- ----------
139
- prompt : str
140
- Input text prompt used to condition the language model.
141
-
142
- Returns
143
- -------
144
- str
145
- The generated continuation text, decoded into a string with special
146
- tokens removed and leading/trailing whitespace stripped.
147
-
148
- """
149
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
150
-
151
- with torch.no_grad():
152
- outputs = model.generate(
153
- **inputs,
154
- max_new_tokens=config.reasoning_max_len,
155
- temperature=config.temperature,
156
- repetition_penalty = config.repetition_penalty,
157
- )
158
-
159
- generated = outputs[0][inputs["input_ids"].shape[-1]:]
160
-
161
- return tokenizer.decode(generated, skip_special_tokens=True).strip()
162
 
163
 
164
  class Action(BaseModel):
@@ -880,45 +849,52 @@ def tool_executor(state: AgentState):
880
 
881
  elif action.tool == "visit_webpage":
882
  try:
883
- webpage_results = visit_webpage_wiki(result)
884
- webpage_result = " \n ".join(webpage_results)
885
-
886
- # for webpage_result in webpage_results:
887
- query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
888
- webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
889
- query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
890
-
891
- # logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
892
-
893
- if query_webpage_information_similarity_score > 0.65:
894
- webpage_information_complete += webpage_result
895
- webpage_information_complete += " \n "
896
- webpage_information_complete += " \n "
897
 
898
- if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
899
- best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
900
- best_webpage_information = webpage_result
901
-
 
 
 
 
 
 
 
 
 
 
 
902
 
903
 
904
- webpage_results = visit_webpage_main(result)
905
- webpage_result = " \n ".join(webpage_results)
906
 
907
- # for webpage_result in webpage_results:
908
- query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
909
- webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
910
- query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
911
-
912
- # logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
913
-
914
- if query_webpage_information_similarity_score > 0.65:
915
- webpage_information_complete += webpage_result
916
- webpage_information_complete += " \n "
917
- webpage_information_complete += " \n "
918
 
919
- if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
920
- best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
921
- best_webpage_information = webpage_result
 
 
 
 
 
 
 
 
 
 
 
 
922
  except:
923
  pass
924
  else:
@@ -985,9 +961,10 @@ class BasicAgent:
985
  # if question == "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?" or question == "What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?":
986
  # if question == "The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.":
987
  # if question == "Where were the Vietnamese specimens described by Kuznetzov in Nedoshivina's 2010 paper eventually deposited? Just give me the city name without abbreviations.":
 
988
 
989
- if question != "aalskdalsdh" and filename == "":
990
- time.sleep(60)
991
 
992
 
993
  state = {
 
26
  from langgraph.prebuilt import ToolNode, tools_condition
27
  from sentence_transformers import SentenceTransformer
28
  from sklearn.metrics.pairwise import cosine_similarity
29
+ from youtube_transcript_api import YouTubeTranscriptApi
30
  # from langchain.agents import create_tool_calling_agent
31
 
32
  # (Keep Constants as is)
 
43
  self.random_state = 42
44
  self.max_len = 256
45
  self.reasoning_max_len = 256
46
+ self.temperature = 0.01
47
  self.repetition_penalty = 1.2
48
  self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
49
  self.model_name = "Qwen/Qwen2.5-7B-Instruct"
50
+ # self.reasoning_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
51
  # self.reasoning_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
52
  # self.reasoning_model_name = "Qwen/Qwen2.5-7B-Instruct"
 
53
 
54
 
55
  config = Config()
 
59
  model = AutoModelForCausalLM.from_pretrained(
60
  config.model_name,
61
  torch_dtype=torch.float16,
62
+ device_map=config.DEVICE
63
  )
64
 
65
  # reasoning_tokenizer = AutoTokenizer.from_pretrained(config.reasoning_model_name)
66
  # reasoning_model = AutoModelForCausalLM.from_pretrained(
67
  # config.reasoning_model_name,
68
  # torch_dtype=torch.float16,
69
+ # device_map=config.DEVICE
70
  # )
71
 
72
  def generate(prompt):
 
128
  generated = outputs[0][inputs["input_ids"].shape[-1]:]
129
 
130
  return tokenizer.decode(generated, skip_special_tokens=True).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
 
133
  class Action(BaseModel):
 
849
 
850
  elif action.tool == "visit_webpage":
851
  try:
852
+ if "www.youtube.com" in str(action.args["url"]):
853
+ video_id = action.args["url"].split("www.youtube.com/watch?v=")[-1]
854
+ api = YouTubeTranscriptApi()
855
+ transcript = api.fetch(video_id)
856
+ texts = [x.text for x in transcript]
857
+ webpage_information_complete = " \n ".join([x.text for x in transcript])
858
+ else:
859
+ webpage_results = visit_webpage_wiki(result)
860
+ webpage_result = " \n ".join(webpage_results)
 
 
 
 
 
861
 
862
+ # for webpage_result in webpage_results:
863
+ query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
864
+ webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
865
+ query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
866
+
867
+ # logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
868
+
869
+ if query_webpage_information_similarity_score > 0.65:
870
+ webpage_information_complete += webpage_result
871
+ webpage_information_complete += " \n "
872
+ webpage_information_complete += " \n "
873
+
874
+ if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
875
+ best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
876
+ best_webpage_information = webpage_result
877
 
878
 
 
 
879
 
880
+ webpage_results = visit_webpage_main(result)
881
+ webpage_result = " \n ".join(webpage_results)
 
 
 
 
 
 
 
 
 
882
 
883
+ # for webpage_result in webpage_results:
884
+ query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
885
+ webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
886
+ query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
887
+
888
+ # logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
889
+
890
+ if query_webpage_information_similarity_score > 0.65:
891
+ webpage_information_complete += webpage_result
892
+ webpage_information_complete += " \n "
893
+ webpage_information_complete += " \n "
894
+
895
+ if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
896
+ best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
897
+ best_webpage_information = webpage_result
898
  except:
899
  pass
900
  else:
 
961
  # if question == "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?" or question == "What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?":
962
  # if question == "The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.":
963
  # if question == "Where were the Vietnamese specimens described by Kuznetzov in Nedoshivina's 2010 paper eventually deposited? Just give me the city name without abbreviations.":
964
+ if question == "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal\'c say in response to the question 'Isn\'t that hot?'":
965
 
966
+ # if question != "aalskdalsdh" and filename == "":
967
+ time.sleep(120)
968
 
969
 
970
  state = {