acecalisto3 commited on
Commit
e6190d8
·
verified ·
1 Parent(s): 3d4cea4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  from typing import Dict, List, Optional
 
3
  from transformers import (
4
  AutoConfig,
5
  AutoModelForSequenceClassification,
@@ -10,6 +11,17 @@ from transformers import (
10
  TrainingArguments,
11
  )
12
 
 
 
 
 
 
 
 
 
 
 
 
13
  class MockOpenAI:
14
  """
15
  A mock implementation of OpenAI's API using Hugging Face's pipeline for text generation.
@@ -96,11 +108,9 @@ class MockOpenAI:
96
 
97
  # Example usage
98
  if __name__ == "__main__":
99
- parser = HfArgumentParser(description="Mock OpenAI API using Hugging Face's pipeline for text generation.")
100
- parser.add_argument("--model_name", default="gpt2", help="The name of the pretrained model to use for text generation.")
101
- parser.add_argument("--max_tokens", type=int, default=50, help="The maximum number of tokens to generate in the response.")
102
- args = parser.parse_args()
103
- client = MockOpenAI(model_name=args.model_name, max_tokens=args.max_tokens)
104
  chat_completion = client.chat.Completions().create(
105
  messages=[
106
  {
@@ -114,4 +124,4 @@ if __name__ == "__main__":
114
  ]
115
  )
116
 
117
- print(chat_completion)
 
1
  import os
2
  from typing import Dict, List, Optional
3
+ from dataclasses import dataclass, field
4
  from transformers import (
5
  AutoConfig,
6
  AutoModelForSequenceClassification,
 
11
  TrainingArguments,
12
  )
13
 
14
+ @dataclass
15
+ class ModelArguments:
16
+ model_name: str = field(
17
+ default="gpt2",
18
+ metadata={"help": "The name of the pretrained model to use for text generation."}
19
+ )
20
+ max_tokens: int = field(
21
+ default=50,
22
+ metadata={"help": "The maximum number of tokens to generate in the response."}
23
+ )
24
+
25
  class MockOpenAI:
26
  """
27
  A mock implementation of OpenAI's API using Hugging Face's pipeline for text generation.
 
108
 
109
  # Example usage
110
  if __name__ == "__main__":
111
+ parser = HfArgumentParser((ModelArguments,))
112
+ model_args = parser.parse_args_into_dataclasses()[0]
113
+ client = MockOpenAI(model_name=model_args.model_name, max_tokens=model_args.max_tokens)
 
 
114
  chat_completion = client.chat.Completions().create(
115
  messages=[
116
  {
 
124
  ]
125
  )
126
 
127
+ print(chat_completion)