AhmadMustafa commited on
Commit
37b73c5
·
1 Parent(s): 768a878

update: add streamimng

Browse files
Files changed (1) hide show
  1. app.py +138 -101
app.py CHANGED
@@ -411,11 +411,21 @@ Rank topics based on their potential virality and engagement for social media cl
411
  },
412
  {"role": "user", "content": prompt},
413
  ],
 
414
  )
415
- return completion.choices[0].message.content
 
 
 
 
 
 
 
 
 
416
  except Exception as e:
417
  print(f"Error in initial analysis: {str(e)}")
418
- return "An error occurred during initial analysis. Please check your API key and file path."
419
 
420
 
421
  def chat(
@@ -498,63 +508,77 @@ In the URL, make sure that after RSID there is ? and then rest of the fields are
498
  messages.append({"role": "user", "content": message})
499
 
500
  completion = client.chat.completions.create(
501
- model="gpt-4o-mini",
502
- messages=messages,
503
- tools=tools,
504
  )
 
 
 
 
 
 
 
 
 
 
 
 
505
 
506
- response = completion.choices[0].message
507
- messages.append(response)
508
-
509
- if response.tool_calls:
510
- if response.tool_calls[0].function.name == "correct_speaker_name_with_url":
511
- args = eval(response.tool_calls[0].function.arguments)
512
- url = args.get("url", None)
513
- if url:
514
- transcript_processor.correct_speaker_mapping_with_agenda(url)
515
- corrected_speaker_mapping = transcript_processor.speaker_mapping
516
- function_call_result_message = {
517
- "role": "tool",
518
- "content": json.dumps(
519
- {
520
- "speaker_mapping": f"Corrected Speaker Mapping is: {corrected_speaker_mapping}\n, All speakers should be addressed via this mapping. The next message should be comparing old speaker names (do not use spk_0, spk_1, use the old names) and corrected speaker names.",
 
 
 
 
521
  }
522
- ),
523
- "name": response.tool_calls[0].function.name,
524
- "tool_call_id": response.tool_calls[0].id,
525
- }
526
-
527
- # messages.append(response.choices[0]["message"])
528
- messages.append(function_call_result_message)
529
- completion_payload = {"model": "gpt-4o-mini", "messages": messages}
530
- # print("messages", messages[3])
531
- response = client.chat.completions.create(**completion_payload)
532
- # print("no error here")
533
-
534
- return response.choices[0].message.content
535
-
536
- else:
537
- return "No URL provided for correcting speaker names."
538
- elif response.tool_calls[0].function.name == "correct_call_type":
539
- args = eval(response.tool_calls[0].function.arguments)
540
- call_type = args.get("call_type", None)
541
- if call_type:
542
- analysis = get_initial_analysis(
543
- transcript_processor, call_type, rsid, origin, call_type
544
- )
545
- return analysis
546
- else:
547
- return "No call type provided for correction."
548
- else:
549
- return "No tool for this request"
550
- return response.content
551
 
552
  except Exception as e:
553
  print(f"Unexpected error in chat: {str(e)}")
554
  import traceback
555
 
556
  print(f"Traceback: {traceback.format_exc()}")
557
- return "Sorry, there was an error processing your request."
558
 
559
 
560
  def create_chat_interface():
@@ -688,18 +712,62 @@ def create_chat_interface():
688
  iframe_html = "<iframe id='link-frame'></iframe>"
689
  gr.HTML(value=iframe_html) # Add iframe to the UI
690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
  def on_app_load(request: gr.Request):
692
  cid = request.query_params.get("cid", None)
693
  rsid = request.query_params.get("rsid", None)
694
  origin = request.query_params.get("origin", None)
695
  ct = request.query_params.get("ct", None)
696
  turl = request.query_params.get("turl", None)
 
697
  required_params = ["cid", "rsid", "origin", "ct", "turl"]
698
  missing_params = [
699
  param
700
  for param in required_params
701
  if request.query_params.get(param) is None
702
  ]
 
703
  if missing_params:
704
  error_message = (
705
  f"Missing required parameters: {', '.join(missing_params)}"
@@ -715,33 +783,16 @@ def create_chat_interface():
715
  None,
716
  ]
717
 
718
- # if any param is missing, return error
719
- if not cid or not rsid or not origin or not ct or not turl:
720
- error_message = "Error processing"
721
- chatbot_value = [(None, error_message)]
722
- return [
723
- chatbot_value,
724
- None,
725
- None,
726
- None,
727
- None,
728
- None,
729
- None,
730
- ]
731
-
732
  try:
733
  transcript_data = get_transcript_for_url(turl)
734
  transcript_processor = TranscriptProcessor(
735
  transcript_data=transcript_data
736
  )
737
- initial_analysis = get_initial_analysis(
738
- transcript_processor, cid, rsid, origin, ct
739
- )
740
 
741
- chatbot_value = [
742
- (None, initial_analysis)
743
- ] # initialized with initial analysis and assistant is None
744
 
 
745
  return [
746
  chatbot_value,
747
  transcript_processor,
@@ -764,6 +815,19 @@ def create_chat_interface():
764
  None,
765
  ]
766
 
 
 
 
 
 
 
 
 
 
 
 
 
 
767
  demo.load(
768
  on_app_load,
769
  inputs=None,
@@ -776,36 +840,9 @@ def create_chat_interface():
776
  ct_state,
777
  turl_state,
778
  ],
779
- )
780
-
781
- def respond(
782
- message: str,
783
- chat_history: List,
784
- transcript_processor,
785
- cid,
786
- rsid,
787
- origin,
788
- ct,
789
- ):
790
- if not transcript_processor:
791
- bot_message = "Transcript processor not initialized."
792
- else:
793
- bot_message = chat(
794
- message,
795
- chat_history,
796
- transcript_processor,
797
- cid,
798
- rsid,
799
- origin,
800
- ct,
801
- )
802
- chat_history.append((message, bot_message))
803
- return "", chat_history
804
-
805
- msg.submit(
806
- respond,
807
- [
808
- msg,
809
  chatbot,
810
  transcript_processor_state,
811
  call_id_state,
@@ -813,7 +850,7 @@ def create_chat_interface():
813
  origin_state,
814
  ct_state,
815
  ],
816
- [msg, chatbot],
817
  )
818
 
819
  return demo
 
411
  },
412
  {"role": "user", "content": prompt},
413
  ],
414
+ stream=True,
415
  )
416
+
417
+ collected_messages = []
418
+ # Iterate through the stream
419
+ for chunk in completion:
420
+ if chunk.choices[0].delta.content is not None:
421
+ chunk_message = chunk.choices[0].delta.content
422
+ collected_messages.append(chunk_message)
423
+ # Yield the accumulated message so far
424
+ yield "".join(collected_messages)
425
+
426
  except Exception as e:
427
  print(f"Error in initial analysis: {str(e)}")
428
+ yield "An error occurred during initial analysis. Please check your API key and file path."
429
 
430
 
431
  def chat(
 
508
  messages.append({"role": "user", "content": message})
509
 
510
  completion = client.chat.completions.create(
511
+ model="gpt-4o-mini", messages=messages, tools=tools, stream=True
 
 
512
  )
513
+ collected_messages = []
514
+ tool_calls_detected = False
515
+
516
+ for chunk in completion:
517
+ if chunk.choices[0].delta.tool_calls:
518
+ tool_calls_detected = True
519
+ # Handle tool calls without streaming
520
+ response = client.chat.completions.create(
521
+ model="gpt-4o-mini",
522
+ messages=messages,
523
+ tools=tools,
524
+ )
525
 
526
+ if response.choices[0].message.tool_calls:
527
+ tool_call = response.choices[0].message.tool_calls[0]
528
+ if tool_call.function.name == "correct_speaker_name_with_url":
529
+ args = eval(tool_call.function.arguments)
530
+ url = args.get("url", None)
531
+ if url:
532
+ transcript_processor.correct_speaker_mapping_with_agenda(
533
+ url
534
+ )
535
+ corrected_speaker_mapping = (
536
+ transcript_processor.speaker_mapping
537
+ )
538
+ function_call_result_message = {
539
+ "role": "tool",
540
+ "content": json.dumps(
541
+ {"speaker_mapping": f"Corrected Speaker Mapping..."}
542
+ ),
543
+ "name": tool_call.function.name,
544
+ "tool_call_id": tool_call.id,
545
  }
546
+ messages.append(function_call_result_message)
547
+
548
+ # Get final response after tool call
549
+ final_response = client.chat.completions.create(
550
+ model="gpt-4o-mini", messages=messages, stream=True
551
+ )
552
+
553
+ # Stream the final response
554
+ for final_chunk in final_response:
555
+ if final_chunk.choices[0].delta.content:
556
+ yield final_chunk.choices[0].delta.content
557
+ return
558
+
559
+ elif tool_call.function.name == "correct_call_type":
560
+ args = eval(tool_call.function.arguments)
561
+ call_type = args.get("call_type", None)
562
+ if call_type:
563
+ # Stream the analysis for corrected call type
564
+ for content in get_initial_analysis(
565
+ transcript_processor, call_type, rsid, origin, call_type
566
+ ):
567
+ yield content
568
+ return
569
+ break # Exit streaming loop if tool calls detected
570
+
571
+ if not tool_calls_detected and chunk.choices[0].delta.content is not None:
572
+ chunk_message = chunk.choices[0].delta.content
573
+ collected_messages.append(chunk_message)
574
+ yield "".join(collected_messages)
575
 
576
  except Exception as e:
577
  print(f"Unexpected error in chat: {str(e)}")
578
  import traceback
579
 
580
  print(f"Traceback: {traceback.format_exc()}")
581
+ yield "Sorry, there was an error processing your request."
582
 
583
 
584
  def create_chat_interface():
 
712
  iframe_html = "<iframe id='link-frame'></iframe>"
713
  gr.HTML(value=iframe_html) # Add iframe to the UI
714
 
715
+ def respond(
716
+ message: str,
717
+ chat_history: List,
718
+ transcript_processor,
719
+ cid,
720
+ rsid,
721
+ origin,
722
+ ct,
723
+ ):
724
+ if not transcript_processor:
725
+ bot_message = "Transcript processor not initialized."
726
+ chat_history.append((message, bot_message))
727
+ return "", chat_history
728
+
729
+ chat_history.append((message, ""))
730
+ for chunk in chat(
731
+ message,
732
+ chat_history[:-1], # Exclude the current incomplete message
733
+ transcript_processor,
734
+ cid,
735
+ rsid,
736
+ origin,
737
+ ct,
738
+ ):
739
+ chat_history[-1] = (message, chunk)
740
+ yield "", chat_history
741
+
742
+ msg.submit(
743
+ respond,
744
+ [
745
+ msg,
746
+ chatbot,
747
+ transcript_processor_state,
748
+ call_id_state,
749
+ colab_id_state,
750
+ origin_state,
751
+ ct_state,
752
+ ],
753
+ [msg, chatbot],
754
+ )
755
+
756
+ # Handle initial loading with streaming
757
  def on_app_load(request: gr.Request):
758
  cid = request.query_params.get("cid", None)
759
  rsid = request.query_params.get("rsid", None)
760
  origin = request.query_params.get("origin", None)
761
  ct = request.query_params.get("ct", None)
762
  turl = request.query_params.get("turl", None)
763
+
764
  required_params = ["cid", "rsid", "origin", "ct", "turl"]
765
  missing_params = [
766
  param
767
  for param in required_params
768
  if request.query_params.get(param) is None
769
  ]
770
+
771
  if missing_params:
772
  error_message = (
773
  f"Missing required parameters: {', '.join(missing_params)}"
 
783
  None,
784
  ]
785
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
786
  try:
787
  transcript_data = get_transcript_for_url(turl)
788
  transcript_processor = TranscriptProcessor(
789
  transcript_data=transcript_data
790
  )
 
 
 
791
 
792
+ # Initialize with empty message
793
+ chatbot_value = [(None, "")]
 
794
 
795
+ # Return initial values with the transcript processor
796
  return [
797
  chatbot_value,
798
  transcript_processor,
 
815
  None,
816
  ]
817
 
818
+ def stream_initial_analysis(
819
+ chatbot_value, transcript_processor, cid, rsid, origin, ct
820
+ ):
821
+ if transcript_processor:
822
+ for chunk in get_initial_analysis(
823
+ transcript_processor, cid, rsid, origin, ct
824
+ ):
825
+ chatbot_value[0] = (None, chunk)
826
+ yield chatbot_value
827
+ else:
828
+ yield chatbot_value
829
+
830
+ # Modified load event to handle streaming
831
  demo.load(
832
  on_app_load,
833
  inputs=None,
 
840
  ct_state,
841
  turl_state,
842
  ],
843
+ ).then(
844
+ stream_initial_analysis,
845
+ inputs=[
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
846
  chatbot,
847
  transcript_processor_state,
848
  call_id_state,
 
850
  origin_state,
851
  ct_state,
852
  ],
853
+ outputs=[chatbot],
854
  )
855
 
856
  return demo