mskov Firefly777a committed on
Commit 1575629
1 Parent(s): ed8df2e

Major changes to the app to allow prompt engineering (#2)


- Major changes to the app to allow prompt engineering (ad35daaa7d851793c89104eff8bf4912e5c2dc76)


Co-authored-by: Maddie <Firefly777a@users.noreply.huggingface.co>

Files changed (1)
  1. app.py +47 -82
app.py CHANGED
@@ -1,14 +1,10 @@
 
  '''
- This script calls the ada model from openai api to predict the next few words.
+ This script calls the model from openai api to predict the next few words.
  '''
  import os
- os.system("pip install --upgrade pip")
  from pprint import pprint
- os.system("pip install git+https://github.com/openai/whisper.git")
  import sys
- print("Sys: ", sys.executable)
- os.system("pip install openai")
  import openai
  import gradio as gr
  import whisper
@@ -17,68 +13,47 @@ import torch
  from transformers import AutoModelForCausalLM
  from transformers import AutoTokenizer
  import time
- # import streaming.py
- # from next_word_prediction import GPT2
 
-
-
-
- #gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
- #tokenizer = AutoTokenizer.from_pretrained("gpt2")
-
- ### /code snippet
-
-
- # get gpt2 model
- #generator = pipeline('text-generation', model='gpt2')
-
- # whisper model specification
- model = whisper.load_model("tiny")
-
-
-
- def inference(audio, state=""):
-
-     #time.sleep(2)
-     #text = p(audio)["text"]
-     #state += text + " "
-     # load audio data
-     audio = whisper.load_audio(audio)
-     # ensure sample is in correct format for inference
-     audio = whisper.pad_or_trim(audio)
-
-     # generate a log-mel spetrogram of the audio data
-     mel = whisper.log_mel_spectrogram(audio).to(model.device)
+ EXAMPLE_PROMPT = """This is a tool for helping someone with memory issues remember the next word.
+ The predictions follow a few rules:
+ 1) The predictions are suggestions of ways to continue the transcript as if someone forgot what the next word was.
+ 2) The predictions do not repeat themselves.
+ 3) The predictions focus on suggesting nouns, adjectives, and verbs.
+ 4) The predictions are related to the context in the transcript.
 
-     _, probs = model.detect_language(mel)
-
-     # decode audio data
-     options = whisper.DecodingOptions(fp16 = False)
-     # transcribe speech to text
-     result = whisper.decode(model, mel, options)
-     print("result pre gp model from whisper: ", result, ".text ", result.text, "and the data type: ", type(result.text))
-
-     PROMPT = """The following is an incomplete transcript of a brief conversation. Predict a list of the next most probable words to complete the sentence.
- Some examples:
- Transcript1: Tomorrow night we're going out to
- Predictions1: the movies, a restaurant, a baseball game, the theater, a party for a friend
- Transcript2: I would like to order a cheeseburger with a side of
- Predictions2: french fries, milkshake, apple slices, salad, extra catsup
- Transcript3: My friend Savanah is
- Predictions3: an electrical engineer, a marine biologist, a classical musician
- Transcript4: I need to buy a birthday
- Predictions4: present, gift, cake, card
-
- Transcript5: """
-     text = PROMPT + result.text + "Prediction5: "
+ EXAMPLES:
+ Transcript: Tomorrow night we're going out to
+ Prediction: The Movies, A Restaurant, A Baseball Game, The Theater, A Party for a friend
+ Transcript: I would like to order a cheeseburger with a side of
+ Prediction: French fries, Milkshake, Apple slices, Side salad, Extra catsup
+ Transcript: My friend Savanah is
+ Prediction: An electrical engineer, A marine biologist, A classical musician
+ Transcript: I need to buy a birthday
+ Prediction: Present, Gift, Cake, Card
+ Transcript: """
+
+ # whisper model specification
+ asr_model = whisper.load_model("tiny")
+
+ openai.api_key = os.environ["Openai_APIkey"]
+
+ # Transcribe function
+ def transcribe(audio_file):
+     print("Transcribing")
+     transcription = asr_model.transcribe(audio_file)["text"]
+     return transcription
+
+ def debug_inference(audio, prompt, model, temperature, state=""):
+     # Transcribe with Whisper
+     print("The audio is:", audio)
+     transcript = transcribe(audio)
 
-     openai.api_key = os.environ["Openai_APIkey"]
+     text = prompt + transcript + "\nPrediction: "
 
      response = openai.Completion.create(
-         model="text-ada-001",
-         #model="text-curie-001",
+         model=model,
          prompt=text,
-         temperature=1,
+         temperature=temperature,
          max_tokens=8,
          n=5)
 
@@ -96,27 +71,17 @@ Transcript5: """
      infers = list(map(lambda x: x.replace("\n", ""), temp))
      #infered = list(map(lambda x: x.split(','), infers))
 
-
-
-
-     # result.text
-     #return getText, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
-     return result.text, state, infers
-
-
+     return transcript, state, infers, text
 
  # get audio from microphone
-
  gr.Interface(
-     fn=inference,
-     inputs=[
-         gr.inputs.Audio(source="microphone", type="filepath"),
-         "state"
-     ],
-     outputs=[
-         "textbox",
-         "state",
-         "textbox"
-     ],
-     live=True).launch()
-
+     fn=debug_inference,
+     inputs=[gr.inputs.Audio(source="microphone", type="filepath"),
+             gr.inputs.Textbox(lines=15, placeholder="Enter a prompt here"),
+             gr.inputs.Dropdown(["text-ada-001", "text-davinci-002", "text-davinci-003", "gpt-3.5-turbo"], label="Model"),
+             gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.8, step=0.1, label="Temperature"),
+             "state"
+             ],
+     outputs=["textbox","state","textbox", "textbox"],
+     # examples=[["example_in-the-mood-to-eat.m4a", EXAMPLE_PROMPT, "text-ada-001", 0.8, ""],["","","",0.9,""]],
+     live=False).launch()
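
For readers skimming the diff, the new flow reduces to: transcribe the microphone audio with Whisper, prepend the user-editable prompt from the textbox, and ask the chosen Completions model for five short continuations. Below is a minimal standalone sketch of that flow, assuming the pre-1.0 openai Python client this diff uses; "sample.wav" and the inline stand-in prompt are hypothetical, not part of the commit.

    import os

    import openai
    import whisper

    openai.api_key = os.environ["Openai_APIkey"]  # same env var the app reads
    asr_model = whisper.load_model("tiny")

    # Stand-in for the prompt textbox contents (EXAMPLE_PROMPT in app.py).
    prompt = "Transcript: I need to buy a birthday\nPrediction: Present, Gift, Cake, Card\nTranscript: "

    transcript = asr_model.transcribe("sample.wav")["text"]  # hypothetical input file
    text = prompt + transcript + "\nPrediction: "

    response = openai.Completion.create(
        model="text-ada-001",  # any option from the app's Model dropdown
        prompt=text,
        temperature=0.8,
        max_tokens=8,
        n=5,
    )

    # Mirror the app's post-processing: strip newlines from each completion.
    infers = [choice.text.replace("\n", "") for choice in response.choices]
    print(transcript, infers)

Note the design change this enables: returning text as a fourth output and switching to live=False means each run also displays the exact prompt that was sent to the model, which is what makes the prompt-engineering loop inspectable.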