saifeddinemk committed • Commit 7d0de60
1 Parent(s): 5aaa320
Fixed app v2
app.py CHANGED
@@ -1,78 +1,43 @@
-import torch
-import json
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import AutoModelForCausalLM, AutoTokenizer
 from typing import List
 
 # Initialize the FastAPI app
 app = FastAPI()
 
-#
-
-
-
-model = AutoModelForCausalLM.from_pretrained(
-    model_path,
-    torch_dtype=torch.float16,
-    device_map="auto",
-    load_in_4bit=False,
-    trust_remote_code=False,
-)
-
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-
-# Function to generate text
-def generate_text(instruction):
-    tokens = tokenizer.encode(instruction)
-    tokens = torch.LongTensor(tokens).unsqueeze(0)
-    tokens = tokens.to("cuda")
-
-    instance = {
-        "input_ids": tokens,
-        "top_p": 1.0,
-        "temperature": 0.75,
-        "generate_len": 2048,
-        "top_k": 50,
-    }
-
-    length = len(tokens[0])
-    with torch.no_grad():
-        rest = model.generate(
-            input_ids=tokens,
-            max_length=length + instance["generate_len"],
-            use_cache=True,
-            do_sample=True,
-            top_p=instance["top_p"],
-            temperature=instance["temperature"],
-            top_k=instance["top_k"],
-            num_return_sequences=1,
-            pad_token_id=tokenizer.eos_token_id,
-        )
-    output = rest[0][length:]
-    string = tokenizer.decode(output, skip_special_tokens=True)
-    return f"{string}"
 
 # Data model for FastAPI input
 class UserInput(BaseModel):
     conversation: str
     user_input: str
 
 @app.post("/generate/")
 async def generate_response(user_input: UserInput):
     try:
-        # Construct the prompt
-
-
-
         # Generate response
-
-
-
-
 
         return {
-            "
             "updated_conversation": updated_conversation
         }
     except Exception as e:
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
+from transformers import pipeline, set_seed
 from typing import List
 
 # Initialize the FastAPI app
 app = FastAPI()
 
+# Initialize the generator pipeline
+generator = pipeline('text-generation', model='gpt2-medium')
+set_seed(42)
 
 # Data model for FastAPI input
 class UserInput(BaseModel):
     conversation: str
     user_input: str
+    max_length: int = 50  # default length
+    num_return_sequences: int = 1  # default number of sequences
 
 @app.post("/generate/")
 async def generate_response(user_input: UserInput):
     try:
+        # Construct the prompt from the conversation and user input
+        prompt = f"{user_input.conversation}{user_input.user_input}"
+
         # Generate response
+        responses = generator(
+            prompt,
+            max_length=user_input.max_length,
+            num_return_sequences=user_input.num_return_sequences
+        )
+
+        # Extract text from each generated sequence
+        generated_texts = [response["generated_text"] for response in responses]
 
+        # Update conversation with the last generated text
+        updated_conversation = f"{prompt}\n{generated_texts[-1]}"
+
         return {
+            "responses": generated_texts,
             "updated_conversation": updated_conversation
         }
     except Exception as e:
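For reference, a minimal sketch of calling the updated `/generate/` endpoint. The host, port, and use of `uvicorn` are assumptions for illustration, not part of this commit; the request body follows the `UserInput` model defined above.

```python
# Sketch: exercising the /generate/ endpoint added in this commit.
# Assumes the app is running locally, e.g. `uvicorn app:app --port 8000`
# (host and port are assumptions, not part of the commit).
import requests

payload = {
    "conversation": "User: Hello!\nBot: Hi, how can I help?\n",
    "user_input": "User: Tell me a short story.\nBot:",
    "max_length": 80,           # optional, defaults to 50 in UserInput
    "num_return_sequences": 1,  # optional, defaults to 1 in UserInput
}

resp = requests.post("http://localhost:8000/generate/", json=payload)
resp.raise_for_status()
data = resp.json()

print(data["responses"][0])          # text produced by the gpt2-medium pipeline
print(data["updated_conversation"])  # prompt plus the last generated text
```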