Prathap commited on
Commit
a2d38fb
1 Parent(s): af9e942

Upload gpt.py

Browse files
Files changed (1) hide show
  1. gpt.py +83 -0
gpt.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Creates the Example and GPT classes for a user to interface with the OpenAI API."""
2
+
3
+ import openai
4
+
5
+
6
+ def set_openai_key(key):
7
+ """Sets OpenAI key."""
8
+ openai.api_key = key
9
+
10
+ class Example():
11
+ """Stores an input, output pair and formats it to prime the model."""
12
+
13
+ def __init__(self, inp, out):
14
+ self.input = inp
15
+ self.output = out
16
+
17
+ def get_input(self):
18
+ """Returns the input of the example."""
19
+ return self.input
20
+
21
+ def get_output(self):
22
+ """Returns the intended output of the example."""
23
+ return self.output
24
+
25
+ def format(self):
26
+ """Formats the input, output pair."""
27
+ return f"input: {self.input}\noutput: {self.output}\n"
28
+
29
+
30
+ class GPT:
31
+ """The main class for a user to interface with the OpenAI API.
32
+ A user can add examples and set parameters of the API request."""
33
+
34
+ def __init__(self, engine='davinci',
35
+ temperature=0.5,
36
+ max_tokens=100):
37
+ self.examples = []
38
+ self.engine = engine
39
+ self.temperature = temperature
40
+ self.max_tokens = max_tokens
41
+
42
+ def add_example(self, ex):
43
+ """Adds an example to the object. Example must be an instance
44
+ of the Example class."""
45
+ assert isinstance(ex, Example), "Please create an Example object."
46
+ self.examples.append(ex.format())
47
+
48
+ def get_prime_text(self):
49
+ """Formats all examples to prime the model."""
50
+ return '\n'.join(self.examples) + '\n'
51
+
52
+ def get_engine(self):
53
+ """Returns the engine specified for the API."""
54
+ return self.engine
55
+
56
+ def get_temperature(self):
57
+ """Returns the temperature specified for the API."""
58
+ return self.temperature
59
+
60
+ def get_max_tokens(self):
61
+ """Returns the max tokens specified for the API."""
62
+ return self.max_tokens
63
+
64
+ def craft_query(self, prompt):
65
+ """Creates the query for the API request."""
66
+ return self.get_prime_text() + "input: " + prompt + "\n"
67
+
68
+ def submit_request(self, prompt):
69
+ """Calls the OpenAI API with the specified parameters."""
70
+ response = openai.Completion.create(engine=self.get_engine(),
71
+ prompt=self.craft_query(prompt),
72
+ max_tokens=self.get_max_tokens(),
73
+ temperature=self.get_temperature(),
74
+ top_p=1,
75
+ n=1,
76
+ stream=False,
77
+ stop="\ninput:")
78
+ return response
79
+
80
+ def get_top_reply(self, prompt):
81
+ """Obtains the best result as returned by the API."""
82
+ response = self.submit_request(prompt)
83
+ return response['choices'][0]['text']