matdmiller commited on
Commit
1c742b7
1 Parent(s): 951a5b3

fix auth and dropdowns

Browse files
Files changed (2) hide show
  1. app.ipynb +68 -37
  2. app.py +9 -5
app.ipynb CHANGED
@@ -539,6 +539,7 @@
539
  "#| export\n",
540
  "def create_speech(input_text, provider, model='tts-1', voice='alloy', profile: gr.OAuthProfile|None=None, progress=gr.Progress(), **kwargs):\n",
541
  "\n",
 
542
  " verify_authorization(profile)\n",
543
  " start = datetime.now()\n",
544
  "\n",
@@ -683,21 +684,10 @@
683
  },
684
  {
685
  "cell_type": "code",
686
- "execution_count": 25,
687
  "id": "e4fb3159-579b-4271-bc96-4cd1e2816eca",
688
  "metadata": {},
689
- "outputs": [
690
- {
691
- "name": "stderr",
692
- "output_type": "stream",
693
- "text": [
694
- "/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/utils.py:1000: UserWarning: Expected 3 arguments for function <function get_generation_cost at 0x1174f2e80>, received 4.\n",
695
- " warnings.warn(\n",
696
- "/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/utils.py:1008: UserWarning: Expected maximum 3 arguments for function <function get_generation_cost at 0x1174f2e80>, received 4.\n",
697
- " warnings.warn(\n"
698
- ]
699
- }
700
- ],
701
  "source": [
702
  "#| export\n",
703
  "with gr.Blocks(title='TTS', head='TTS', delete_cache=(3600,3600)) as app:\n",
@@ -707,9 +697,12 @@
707
  " with gr.Row():\n",
708
  " input_text = gr.Textbox(max_lines=100, label=\"Enter text here\")\n",
709
  " with gr.Row():\n",
710
- " tts_provider_dropdown = gr.Dropdown(value=DEFAULT_PROVIDER,choices=[(v,k) for k,v in providers.items()], label='Provider')\n",
711
- " tts_model_dropdown = gr.Dropdown(value=DEFAULT_MODEL,choices=get_model_choices(DEFAULT_PROVIDER), label='Model')\n",
712
- " tts_voice_dropdown = gr.Dropdown(value=DEFAULT_VOICE,choices=get_voice_choices(DEFAULT_PROVIDER, DEFAULT_MODEL),label='Voice')\n",
 
 
 
713
  " input_text_length = gr.Label(label=\"Number of characters\")\n",
714
  " generation_cost = gr.Label(label=\"Generation cost\")\n",
715
  " with gr.Row():\n",
@@ -718,8 +711,8 @@
718
  " #input_text \n",
719
  " input_text.input(fn=get_input_text_len, inputs=input_text, outputs=input_text_length)\n",
720
  " input_text.input(fn=get_generation_cost, \n",
721
- " inputs=[input_text,tts_model_dropdown,tts_provider_dropdown, tts_provider_dropdown], \n",
722
- " outputs=tts_voice_dropdown)\n",
723
  "\n",
724
  " tts_provider_dropdown.change(fn=update_model_choices, inputs=[tts_provider_dropdown], \n",
725
  " outputs=tts_model_dropdown)\n",
@@ -746,7 +739,7 @@
746
  },
747
  {
748
  "cell_type": "code",
749
- "execution_count": 26,
750
  "id": "a00648a1-891b-470b-9959-f5d502055713",
751
  "metadata": {},
752
  "outputs": [],
@@ -760,7 +753,7 @@
760
  },
761
  {
762
  "cell_type": "code",
763
- "execution_count": 27,
764
  "id": "4b534fe7-4337-423e-846a-1bdb7cccc4ea",
765
  "metadata": {},
766
  "outputs": [
@@ -789,9 +782,59 @@
789
  "data": {
790
  "text/plain": []
791
  },
792
- "execution_count": 27,
793
  "metadata": {},
794
  "output_type": "execute_result"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
795
  }
796
  ],
797
  "source": [
@@ -817,7 +860,7 @@
817
  },
818
  {
819
  "cell_type": "code",
820
- "execution_count": 28,
821
  "id": "28e8d888-e790-46fa-bbac-4511b9ab796c",
822
  "metadata": {},
823
  "outputs": [
@@ -836,22 +879,10 @@
836
  },
837
  {
838
  "cell_type": "code",
839
- "execution_count": 2,
840
  "id": "afbc9699-4d16-4060-88f4-cd1251754cbd",
841
  "metadata": {},
842
- "outputs": [
843
- {
844
- "ename": "NameError",
845
- "evalue": "name 'gr' is not defined",
846
- "output_type": "error",
847
- "traceback": [
848
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
849
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
850
- "Cell \u001b[0;32mIn[2], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| hide\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mgr\u001b[49m\u001b[38;5;241m.\u001b[39mclose_all()\n",
851
- "\u001b[0;31mNameError\u001b[0m: name 'gr' is not defined"
852
- ]
853
- }
854
- ],
855
  "source": [
856
  "#| hide\n",
857
  "gr.close_all()"
@@ -859,7 +890,7 @@
859
  },
860
  {
861
  "cell_type": "code",
862
- "execution_count": 30,
863
  "id": "0420310d-930b-4904-8bd4-3458ad8bdbd3",
864
  "metadata": {},
865
  "outputs": [],
 
539
  "#| export\n",
540
  "def create_speech(input_text, provider, model='tts-1', voice='alloy', profile: gr.OAuthProfile|None=None, progress=gr.Progress(), **kwargs):\n",
541
  "\n",
542
+ " #Verify auth if it is required. This is very important if this is in a HF space. DO NOT DELETE!!!\n",
543
  " verify_authorization(profile)\n",
544
  " start = datetime.now()\n",
545
  "\n",
 
684
  },
685
  {
686
  "cell_type": "code",
687
+ "execution_count": 29,
688
  "id": "e4fb3159-579b-4271-bc96-4cd1e2816eca",
689
  "metadata": {},
690
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
691
  "source": [
692
  "#| export\n",
693
  "with gr.Blocks(title='TTS', head='TTS', delete_cache=(3600,3600)) as app:\n",
 
697
  " with gr.Row():\n",
698
  " input_text = gr.Textbox(max_lines=100, label=\"Enter text here\")\n",
699
  " with gr.Row():\n",
700
+ " tts_provider_dropdown = gr.Dropdown(value=DEFAULT_PROVIDER,\n",
701
+ " choices=tuple([(v['name'],k) for k,v in providers.items()]), label='Provider', interactive=True)\n",
702
+ " tts_model_dropdown = gr.Dropdown(value=DEFAULT_MODEL,choices=get_model_choices(DEFAULT_PROVIDER), \n",
703
+ " label='Model', interactive=True)\n",
704
+ " tts_voice_dropdown = gr.Dropdown(value=DEFAULT_VOICE,choices=get_voice_choices(DEFAULT_PROVIDER, DEFAULT_MODEL),\n",
705
+ " label='Voice', interactive=True)\n",
706
  " input_text_length = gr.Label(label=\"Number of characters\")\n",
707
  " generation_cost = gr.Label(label=\"Generation cost\")\n",
708
  " with gr.Row():\n",
 
711
  " #input_text \n",
712
  " input_text.input(fn=get_input_text_len, inputs=input_text, outputs=input_text_length)\n",
713
  " input_text.input(fn=get_generation_cost, \n",
714
+ " inputs=[input_text,tts_model_dropdown,tts_provider_dropdown], \n",
715
+ " outputs=generation_cost)\n",
716
  "\n",
717
  " tts_provider_dropdown.change(fn=update_model_choices, inputs=[tts_provider_dropdown], \n",
718
  " outputs=tts_model_dropdown)\n",
 
739
  },
740
  {
741
  "cell_type": "code",
742
+ "execution_count": 30,
743
  "id": "a00648a1-891b-470b-9959-f5d502055713",
744
  "metadata": {},
745
  "outputs": [],
 
753
  },
754
  {
755
  "cell_type": "code",
756
+ "execution_count": 31,
757
  "id": "4b534fe7-4337-423e-846a-1bdb7cccc4ea",
758
  "metadata": {},
759
  "outputs": [
 
782
  "data": {
783
  "text/plain": []
784
  },
785
+ "execution_count": 31,
786
  "metadata": {},
787
  "output_type": "execute_result"
788
+ },
789
+ {
790
+ "name": "stderr",
791
+ "output_type": "stream",
792
+ "text": [
793
+ "/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/components/dropdown.py:181: UserWarning: The value passed into gr.Dropdown() is not in the list of choices. Please update the list of choices to include: $0.000 or set allow_custom_value=True.\n",
794
+ " warnings.warn(\n",
795
+ "Traceback (most recent call last):\n",
796
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/queueing.py\", line 532, in process_events\n",
797
+ " response = await route_utils.call_process_api(\n",
798
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
799
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/route_utils.py\", line 276, in call_process_api\n",
800
+ " output = await app.get_blocks().process_api(\n",
801
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
802
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/blocks.py\", line 1928, in process_api\n",
803
+ " result = await self.call_function(\n",
804
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
805
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/blocks.py\", line 1500, in call_function\n",
806
+ " processed_input, progress_index, _ = special_args(\n",
807
+ " ^^^^^^^^^^^^^\n",
808
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/helpers.py\", line 891, in special_args\n",
809
+ " getattr(request, \"session\", {})\n",
810
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/route_utils.py\", line 158, in __getattr__\n",
811
+ " return self.dict_to_obj(getattr(self.request, name))\n",
812
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
813
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/starlette/requests.py\", line 157, in session\n",
814
+ " \"session\" in self.scope\n",
815
+ "AssertionError: SessionMiddleware must be installed to access request.session\n",
816
+ "Traceback (most recent call last):\n",
817
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/queueing.py\", line 532, in process_events\n",
818
+ " response = await route_utils.call_process_api(\n",
819
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
820
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/route_utils.py\", line 276, in call_process_api\n",
821
+ " output = await app.get_blocks().process_api(\n",
822
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
823
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/blocks.py\", line 1928, in process_api\n",
824
+ " result = await self.call_function(\n",
825
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
826
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/blocks.py\", line 1500, in call_function\n",
827
+ " processed_input, progress_index, _ = special_args(\n",
828
+ " ^^^^^^^^^^^^^\n",
829
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/helpers.py\", line 891, in special_args\n",
830
+ " getattr(request, \"session\", {})\n",
831
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/gradio/route_utils.py\", line 158, in __getattr__\n",
832
+ " return self.dict_to_obj(getattr(self.request, name))\n",
833
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
834
+ " File \"/Users/mathewmiller/anaconda3/envs/gradio1/lib/python3.11/site-packages/starlette/requests.py\", line 157, in session\n",
835
+ " \"session\" in self.scope\n",
836
+ "AssertionError: SessionMiddleware must be installed to access request.session\n"
837
+ ]
838
  }
839
  ],
840
  "source": [
 
860
  },
861
  {
862
  "cell_type": "code",
863
+ "execution_count": 33,
864
  "id": "28e8d888-e790-46fa-bbac-4511b9ab796c",
865
  "metadata": {},
866
  "outputs": [
 
879
  },
880
  {
881
  "cell_type": "code",
882
+ "execution_count": 34,
883
  "id": "afbc9699-4d16-4060-88f4-cd1251754cbd",
884
  "metadata": {},
885
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
886
  "source": [
887
  "#| hide\n",
888
  "gr.close_all()"
 
890
  },
891
  {
892
  "cell_type": "code",
893
+ "execution_count": 36,
894
  "id": "0420310d-930b-4904-8bd4-3458ad8bdbd3",
895
  "metadata": {},
896
  "outputs": [],
app.py CHANGED
@@ -227,6 +227,7 @@ def create_speech_cartesiaai(chunk_idx, input, model='upbeat-moon',
227
  # %% app.ipynb 25
228
  def create_speech(input_text, provider, model='tts-1', voice='alloy', profile: gr.OAuthProfile|None=None, progress=gr.Progress(), **kwargs):
229
 
 
230
  verify_authorization(profile)
231
  start = datetime.now()
232
 
@@ -319,9 +320,12 @@ For requests longer than allowed by the API they will be broken into chunks auto
319
  with gr.Row():
320
  input_text = gr.Textbox(max_lines=100, label="Enter text here")
321
  with gr.Row():
322
- tts_provider_dropdown = gr.Dropdown(value=DEFAULT_PROVIDER,choices=[(v,k) for k,v in providers.items()], label='Provider')
323
- tts_model_dropdown = gr.Dropdown(value=DEFAULT_MODEL,choices=get_model_choices(DEFAULT_PROVIDER), label='Model')
324
- tts_voice_dropdown = gr.Dropdown(value=DEFAULT_VOICE,choices=get_voice_choices(DEFAULT_PROVIDER, DEFAULT_MODEL),label='Voice')
 
 
 
325
  input_text_length = gr.Label(label="Number of characters")
326
  generation_cost = gr.Label(label="Generation cost")
327
  with gr.Row():
@@ -330,8 +334,8 @@ For requests longer than allowed by the API they will be broken into chunks auto
330
  #input_text
331
  input_text.input(fn=get_input_text_len, inputs=input_text, outputs=input_text_length)
332
  input_text.input(fn=get_generation_cost,
333
- inputs=[input_text,tts_model_dropdown,tts_provider_dropdown, tts_provider_dropdown],
334
- outputs=tts_voice_dropdown)
335
 
336
  tts_provider_dropdown.change(fn=update_model_choices, inputs=[tts_provider_dropdown],
337
  outputs=tts_model_dropdown)
 
227
  # %% app.ipynb 25
228
  def create_speech(input_text, provider, model='tts-1', voice='alloy', profile: gr.OAuthProfile|None=None, progress=gr.Progress(), **kwargs):
229
 
230
+ #Verify auth if it is required. This is very important if this is in a HF space. DO NOT DELETE!!!
231
  verify_authorization(profile)
232
  start = datetime.now()
233
 
 
320
  with gr.Row():
321
  input_text = gr.Textbox(max_lines=100, label="Enter text here")
322
  with gr.Row():
323
+ tts_provider_dropdown = gr.Dropdown(value=DEFAULT_PROVIDER,
324
+ choices=tuple([(v['name'],k) for k,v in providers.items()]), label='Provider', interactive=True)
325
+ tts_model_dropdown = gr.Dropdown(value=DEFAULT_MODEL,choices=get_model_choices(DEFAULT_PROVIDER),
326
+ label='Model', interactive=True)
327
+ tts_voice_dropdown = gr.Dropdown(value=DEFAULT_VOICE,choices=get_voice_choices(DEFAULT_PROVIDER, DEFAULT_MODEL),
328
+ label='Voice', interactive=True)
329
  input_text_length = gr.Label(label="Number of characters")
330
  generation_cost = gr.Label(label="Generation cost")
331
  with gr.Row():
 
334
  #input_text
335
  input_text.input(fn=get_input_text_len, inputs=input_text, outputs=input_text_length)
336
  input_text.input(fn=get_generation_cost,
337
+ inputs=[input_text,tts_model_dropdown,tts_provider_dropdown],
338
+ outputs=generation_cost)
339
 
340
  tts_provider_dropdown.change(fn=update_model_choices, inputs=[tts_provider_dropdown],
341
  outputs=tts_model_dropdown)