SpiceyToad committed
Commit 0abf936 · Parent: f1cf6cf

Update app.py

Files changed (1): app.py (+59 -51)
app.py CHANGED
@@ -1,51 +1,59 @@
- import os
- from fastapi import FastAPI, Request
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch
-
- # Set Hugging Face cache directory
- os.environ["HF_HOME"] = "/home/user/cache"
-
- # Get Hugging Face API token
- HF_API_TOKEN = os.getenv("HF_API_TOKEN")
- if not HF_API_TOKEN:
-     raise ValueError("HF_API_TOKEN environment variable is not set!")
-
- app = FastAPI()
-
- # Load Falcon 7B model
- MODEL_NAME = "SpiceyToad/demo-falc"
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_API_TOKEN)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_NAME,
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
-     token=HF_API_TOKEN
- )
-
- # Ensure tokenizer has a padding token
- if tokenizer.pad_token is None:
-     tokenizer.pad_token = tokenizer.eos_token  # Use the EOS token as the padding token
-
- @app.post("/generate")
- async def generate_text(request: Request):
-     data = await request.json()
-     prompt = data.get("prompt", "")
-     max_length = data.get("max_length", 50)
-
-     # Tokenize with padding and attention mask
-     inputs = tokenizer(
-         prompt,
-         return_tensors="pt",
-         padding=True,
-         truncation=True
-     ).to(model.device)
-
-     outputs = model.generate(
-         inputs["input_ids"],
-         attention_mask=inputs["attention_mask"],
-         max_length=max_length
-     )
-
-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return {"generated_text": response}
+ import os
+ from fastapi import FastAPI, Request
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ # Set Hugging Face cache directory
+ os.environ["HF_HOME"] = "/home/user/cache"
+
+ # Get Hugging Face API token
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
+ if not HF_API_TOKEN:
+     raise ValueError("HF_API_TOKEN environment variable is not set!")
+
+ app = FastAPI()
+
+ # Load Falcon 7B model
+ MODEL_NAME = "SpiceyToad/demo-falc"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_API_TOKEN)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+     use_auth_token=HF_API_TOKEN
+ )
+
+ # Ensure tokenizer has a padding token
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token  # Use the EOS token as the padding token
+
+ @app.post("/generate")
+ async def generate_text(request: Request):
+     data = await request.json()
+     prompt = data.get("prompt", "").strip()
+     max_length = data.get("max_length", 50)
+
+     if not prompt:
+         return {"error": "Prompt is required!"}
+
+     # Validate max_length
+     max_length = min(max_length, model.config.max_position_embeddings)
+
+     # Tokenize with padding and attention mask
+     inputs = tokenizer(
+         prompt,
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+         max_length=max_length
+     ).to(model.device)
+
+     # Generate response
+     outputs = model.generate(
+         inputs["input_ids"],
+         attention_mask=inputs["attention_mask"],
+         max_length=max_length
+     )
+
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return {"generated_text": response}
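For reference, the updated `/generate` endpoint can be exercised with a small client such as the sketch below. This is not part of the commit; it assumes the app is served locally on port 8000 (for example via `uvicorn app:app --port 8000`), and the host and port are placeholders.

```python
import requests

# Hypothetical local deployment; adjust the URL to match your server.
resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Once upon a time", "max_length": 50},
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```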
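One caveat worth noting: `max_length` in `model.generate` counts the prompt tokens as well, so a prompt that is already truncated to `max_length` tokens leaves no budget for newly generated text. A common alternative is `max_new_tokens`, which bounds only the generated continuation. The snippet below is a sketch of that variant, not the code in this commit:

```python
# Sketch: bound only the newly generated tokens, independent of prompt length.
outputs = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=max_length,            # budget for NEW tokens only
    pad_token_id=tokenizer.pad_token_id,  # avoid the missing-pad-token warning
)
```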