Youssefk commited on
Commit
bf0925c
1 Parent(s): adb75cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -57
app.py CHANGED
@@ -7,7 +7,7 @@ from transformers import AutoModelWithLMHead, AutoTokenizer
7
  # model = AutoModelWithLMHead.from_pretrained('model.py')
8
 
9
  # -*- coding: utf-8 -*-
10
-
11
  import pandas as pd
12
 
13
  data = {'Question': ['What is the story about?',
@@ -40,7 +40,6 @@ df = pd.DataFrame(data)
40
 
41
  # ! pip -q install transformers
42
 
43
- from transformers import AutoModelWithLMHead, AutoTokenizer
44
  import torch
45
  import os
46
 
@@ -635,58 +634,58 @@ print(len(test_chatbot))
635
  ####################################
636
  ############Streamlit###############
637
 
638
- st.set_page_config(
639
- page_title="COVID Doctor using DialoGPT",
640
- page_icon=":robot:"
641
- )
642
-
643
- API_URL = "https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
644
- #headers = {"Authorization": st.secrets['api_key']}
645
-
646
- st.header("Hello - Welcome to COVID Doctor using DialoGPT")
647
- st.markdown("[Github](https://github.com/rushic24/DialoGPT-Finetune)")
648
-
649
- if 'generated' not in st.session_state:
650
- st.session_state['generated'] = []
651
-
652
- if 'past' not in st.session_state:
653
- st.session_state['past'] = []
654
-
655
- def query(payload):
656
- bot_input_ids = tokenizer.encode(payload["inputs"]["text"] + tokenizer.eos_token, return_tensors='pt')
657
-
658
- chat_history_ids = model.generate(
659
- bot_input_ids, max_length=100,
660
- pad_token_id=tokenizer.eos_token_id,
661
- no_repeat_ngram_size=3,
662
- do_sample=True,
663
- top_k=10,
664
- top_p=0.7,
665
- temperature = 0.8
666
- )
667
- output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
668
- return {"generated_text": output}
669
-
670
- def get_text():
671
- input_text = st.text_input("You: "," ", key="input")
672
- return input_text
673
-
674
-
675
- user_input = get_text()
676
-
677
- if user_input:
678
- output = query({
679
- "inputs": {
680
- "past_user_inputs": st.session_state.past,
681
- "generated_responses": st.session_state.generated,
682
- "text": user_input,
683
- },"parameters": {"repetition_penalty": 1.33},
684
- })
685
- st.session_state.past.append(user_input)
686
- st.session_state.generated.append(output["generated_text"])
687
-
688
- if st.session_state['generated']:
689
-
690
- for i in range(len(st.session_state['generated'])-1, -1, -1):
691
- message(st.session_state["generated"][i], key=str(i))
692
- message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
 
7
  # model = AutoModelWithLMHead.from_pretrained('model.py')
8
 
9
  # -*- coding: utf-8 -*-
10
+ st.write("yoyoyo")
11
  import pandas as pd
12
 
13
  data = {'Question': ['What is the story about?',
 
40
 
41
  # ! pip -q install transformers
42
 
 
43
  import torch
44
  import os
45
 
 
634
  ####################################
635
  ############Streamlit###############
636
 
637
+ # st.set_page_config(
638
+ # page_title="COVID Doctor using DialoGPT",
639
+ # page_icon=":robot:"
640
+ # )
641
+
642
+ # API_URL = "https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
643
+ # #headers = {"Authorization": st.secrets['api_key']}
644
+
645
+ # st.header("Hello - Welcome to COVID Doctor using DialoGPT")
646
+ # st.markdown("[Github](https://github.com/rushic24/DialoGPT-Finetune)")
647
+
648
+ # if 'generated' not in st.session_state:
649
+ # st.session_state['generated'] = []
650
+
651
+ # if 'past' not in st.session_state:
652
+ # st.session_state['past'] = []
653
+
654
+ # def query(payload):
655
+ # bot_input_ids = tokenizer.encode(payload["inputs"]["text"] + tokenizer.eos_token, return_tensors='pt')
656
+
657
+ # chat_history_ids = model.generate(
658
+ # bot_input_ids, max_length=100,
659
+ # pad_token_id=tokenizer.eos_token_id,
660
+ # no_repeat_ngram_size=3,
661
+ # do_sample=True,
662
+ # top_k=10,
663
+ # top_p=0.7,
664
+ # temperature = 0.8
665
+ # )
666
+ # output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
667
+ # return {"generated_text": output}
668
+
669
+ # def get_text():
670
+ # input_text = st.text_input("You: "," ", key="input")
671
+ # return input_text
672
+
673
+
674
+ # user_input = get_text()
675
+
676
+ # if user_input:
677
+ # output = query({
678
+ # "inputs": {
679
+ # "past_user_inputs": st.session_state.past,
680
+ # "generated_responses": st.session_state.generated,
681
+ # "text": user_input,
682
+ # },"parameters": {"repetition_penalty": 1.33},
683
+ # })
684
+ # st.session_state.past.append(user_input)
685
+ # st.session_state.generated.append(output["generated_text"])
686
+
687
+ # if st.session_state['generated']:
688
+
689
+ # for i in range(len(st.session_state['generated'])-1, -1, -1):
690
+ # message(st.session_state["generated"][i], key=str(i))
691
+ # message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')