Spaces:
Build error
Build error
acecalisto3
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
from typing import Dict, List, Optional
|
|
|
3 |
from transformers import (
|
4 |
AutoConfig,
|
5 |
AutoModelForSequenceClassification,
|
@@ -10,6 +11,17 @@ from transformers import (
|
|
10 |
TrainingArguments,
|
11 |
)
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
class MockOpenAI:
|
14 |
"""
|
15 |
A mock implementation of OpenAI's API using Hugging Face's pipeline for text generation.
|
@@ -96,11 +108,9 @@ class MockOpenAI:
|
|
96 |
|
97 |
# Example usage
|
98 |
if __name__ == "__main__":
|
99 |
-
parser = HfArgumentParser(
|
100 |
-
|
101 |
-
|
102 |
-
args = parser.parse_args()
|
103 |
-
client = MockOpenAI(model_name=args.model_name, max_tokens=args.max_tokens)
|
104 |
chat_completion = client.chat.Completions().create(
|
105 |
messages=[
|
106 |
{
|
@@ -114,4 +124,4 @@ if __name__ == "__main__":
|
|
114 |
]
|
115 |
)
|
116 |
|
117 |
-
print(chat_completion)
|
|
|
1 |
import os
|
2 |
from typing import Dict, List, Optional
|
3 |
+
from dataclasses import dataclass, field
|
4 |
from transformers import (
|
5 |
AutoConfig,
|
6 |
AutoModelForSequenceClassification,
|
|
|
11 |
TrainingArguments,
|
12 |
)
|
13 |
|
14 |
+
@dataclass
|
15 |
+
class ModelArguments:
|
16 |
+
model_name: str = field(
|
17 |
+
default="gpt2",
|
18 |
+
metadata={"help": "The name of the pretrained model to use for text generation."}
|
19 |
+
)
|
20 |
+
max_tokens: int = field(
|
21 |
+
default=50,
|
22 |
+
metadata={"help": "The maximum number of tokens to generate in the response."}
|
23 |
+
)
|
24 |
+
|
25 |
class MockOpenAI:
|
26 |
"""
|
27 |
A mock implementation of OpenAI's API using Hugging Face's pipeline for text generation.
|
|
|
108 |
|
109 |
# Example usage
|
110 |
if __name__ == "__main__":
|
111 |
+
parser = HfArgumentParser((ModelArguments,))
|
112 |
+
model_args = parser.parse_args_into_dataclasses()[0]
|
113 |
+
client = MockOpenAI(model_name=model_args.model_name, max_tokens=model_args.max_tokens)
|
|
|
|
|
114 |
chat_completion = client.chat.Completions().create(
|
115 |
messages=[
|
116 |
{
|
|
|
124 |
]
|
125 |
)
|
126 |
|
127 |
+
print(chat_completion)
|