abhisheksan committed
Commit 67dd542
1 Parent(s): 11be554

Enhance the Poetry Generator API: implement a health check endpoint, improve model loading with logging, and update the request/response models
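From a client's perspective, the reworked surface looks like this (a minimal usage sketch, not part of the commit; it assumes the server is running locally on port 8000 as in main.py's __main__ block, and the request values are hypothetical examples):

    import requests

    # GET /health returns the new ModelInfo payload (status, model name and
    # path, supported styles, max context length); it responds with 503
    # until the model has loaded.
    print(requests.get("http://localhost:8000/health").json())

    # POST /generate takes the extended PoetryRequest, including the new
    # temperature field, and echoes prompt/style back in the response.
    resp = requests.post(
        "http://localhost:8000/generate",
        json={
            "prompt": "autumn rain",   # hypothetical example values
            "style": "haiku",
            "max_length": 100,
            "temperature": 0.7,
        },
    )
    print(resp.json()["poem"])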

Files changed (4):
  1. app/config.py +16 -0
  2. download_model.py +30 -0
  3. main.py +130 -46
  4. requirements.txt +1 -1
app/config.py ADDED
@@ -0,0 +1,16 @@
+ import os
+ from pathlib import Path
+
+ # Base project directory
+ BASE_DIR = Path(__file__).resolve().parent.parent
+
+ # Model settings
+ MODEL_DIR = BASE_DIR / "models"
+ MODEL_NAME = "llama-2-7b-chat.q4_K_M.gguf"
+ MODEL_PATH = MODEL_DIR / MODEL_NAME
+
+ # Ensure model directory exists
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
+
+ # Model download URL
+ MODEL_URL = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf"
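Since BASE_DIR is derived from the location of app/config.py itself, every path resolves relative to the repository root, and importing the module creates models/ as a side effect of the mkdir call. A quick check (a sketch, not part of the commit; the printed paths are illustrative):

    from app.config import BASE_DIR, MODEL_PATH

    print(BASE_DIR)    # the repository root, i.e. the parent of app/
    print(MODEL_PATH)  # <repo>/models/llama-2-7b-chat.q4_K_M.gguf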
download_model.py ADDED
@@ -0,0 +1,30 @@
+ import requests
+ from tqdm import tqdm
+ from app.config import MODEL_PATH, MODEL_URL, MODEL_DIR
+ import sys
+
+ def download_model():
+     """Download the model if it doesn't exist"""
+     if MODEL_PATH.exists():
+         print(f"Model already exists at {MODEL_PATH}")
+         return
+
+     print(f"Downloading model to {MODEL_PATH}")
+     MODEL_DIR.mkdir(parents=True, exist_ok=True)
+
+     response = requests.get(MODEL_URL, stream=True)
+     total_size = int(response.headers.get('content-length', 0))
+
+     with open(MODEL_PATH, 'wb') as file, tqdm(
+         desc="Downloading",
+         total=total_size,
+         unit='iB',
+         unit_scale=True,
+         unit_divisor=1024,
+     ) as pbar:
+         for data in response.iter_content(chunk_size=1024):
+             size = file.write(data)
+             pbar.update(size)
+
+ if __name__ == "__main__":
+     download_model()
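The downloader is idempotent, returning early when the file already exists, so it is safe to run on every deploy (a usage sketch, not part of the commit; the GGUF file is several gigabytes, so expect the first run to take a while):

    # One-time setup, run from the repository root:
    #   python download_model.py
    # or programmatically:
    from download_model import download_model

    download_model()  # no-op if models/llama-2-7b-chat.q4_K_M.gguf exists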
main.py CHANGED
@@ -1,81 +1,165 @@
- from fastapi import FastAPI, HTTPException
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- from transformers import AutoModelForCausalLM
+ from fastapi import FastAPI, HTTPException, status
+ from pydantic import BaseModel, Field
+ from typing import Optional, List
+ from ctransformers import AutoModelForCausalLM
  import time
+ import logging
+ from app.config import MODEL_PATH

+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)

  # Initialize FastAPI app
- app = FastAPI(title="Poetry Generator")
-
- # Add CORS middleware
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
+ app = FastAPI(
+     title="Poetry Generator API",
+     description="An API for generating poetry using a local LLM",
+     version="1.0.0"
  )

- # Initialize the model (lazy loading)
+ # Global model variable
  model = None

- def load_model():
-     global model
-     if model is None:
-         # Load a quantized GGUF model
-         # You can download models from huggingface.co
-         # Example: GPT2 or Llama-2-7b-chat.Q4_K_M.gguf
-         model = AutoModelForCausalLM.from_pretrained(
-             "TheBloke/Llama-2-7B-Chat-GGUF",
-             model_file="llama-2-7b-chat.q4_K_M.gguf",
-             model_type="llama",
-             max_new_tokens=256,
-             context_length=512,
-             gpu_layers=0  # CPU only
-         )
-
  class PoetryRequest(BaseModel):
-     prompt: str
-     style: str = "free verse"
-     max_length: int = 200
+     prompt: str = Field(..., description="The topic or theme for the poem", min_length=1)
+     style: str = Field(
+         default="free verse",
+         description="Style of the poem to generate"
+     )
+     max_length: int = Field(
+         default=200,
+         description="Maximum length of the generated poem",
+         ge=50,
+         le=500
+     )
+     temperature: float = Field(
+         default=0.7,
+         description="Temperature for text generation",
+         ge=0.1,
+         le=2.0
+     )

  class PoetryResponse(BaseModel):
      poem: str
      generation_time: float
+     prompt: str
+     style: str
+
+ class ModelInfo(BaseModel):
+     status: str
+     model_name: str
+     model_path: str
+     supported_styles: List[str]
+     max_context_length: int

  @app.on_event("startup")
  async def startup_event():
-     load_model()
+     """Initialize the model during startup"""
+     global model
+     try:
+         if not MODEL_PATH.exists():
+             raise FileNotFoundError(
+                 f"Model file not found at {MODEL_PATH}. "
+                 "Please run download_model.py first."
+             )
+
+         logger.info(f"Loading model from {MODEL_PATH}")
+         model = AutoModelForCausalLM.from_pretrained(
+             str(MODEL_PATH.parent),
+             model_file=MODEL_PATH.name,
+             model_type="llama",
+             max_new_tokens=512,
+             context_length=512,
+             gpu_layers=0  # CPU only
+         )
+         logger.info("Model loaded successfully")
+     except Exception as e:
+         logger.error(f"Failed to load model: {str(e)}")
+         raise RuntimeError("Failed to initialize model")

- @app.post("/generate_poem", response_model=PoetryResponse)
+ @app.get(
+     "/health",
+     response_model=ModelInfo,
+     status_code=status.HTTP_200_OK,
+     tags=["Health Check"]
+ )
+ async def health_check():
+     """Check if the model is loaded and get basic information"""
+     if model is None:
+         raise HTTPException(
+             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+             detail="Model not loaded"
+         )
+
+     return ModelInfo(
+         status="ready",
+         model_name="Llama-2-7B-Chat",
+         model_path=str(MODEL_PATH),
+         supported_styles=[
+             "free verse",
+             "haiku",
+             "sonnet",
+             "limerick",
+             "tanka"
+         ],
+         max_context_length=512
+     )
+
+ @app.post(
+     "/generate",
+     response_model=PoetryResponse,
+     status_code=status.HTTP_200_OK,
+     tags=["Generation"]
+ )
  async def generate_poem(request: PoetryRequest):
+     """Generate a poem based on the provided prompt and parameters"""
+     if model is None:
+         raise HTTPException(
+             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+             detail="Model not loaded"
+         )
+
      try:
          start_time = time.time()

-         # Construct the prompt
-         full_prompt = f"""Write a {request.style} poem about {request.prompt}.
-         Make it creative and meaningful. The poem should be:
-         """
+         prompt_templates = {
+             "haiku": "Write a haiku about {prompt}. Follow the 5-7-5 syllable pattern:\n\n",
+             "sonnet": "Write a Shakespearean sonnet about {prompt}. Follow the traditional 14-line format with rhyme scheme ABAB CDCD EFEF GG:\n\n",
+             "limerick": "Write a limerick about {prompt}. Follow the AABBA rhyme scheme:\n\n",
+             "free verse": "Write a free verse poem about {prompt}. Make it creative and meaningful:\n\n",
+             "tanka": "Write a tanka about {prompt}. Follow the 5-7-5-7-7 syllable pattern:\n\n"
+         }

-         # Generate the poem
+         template = prompt_templates.get(request.style.lower(), prompt_templates["free verse"])
+         full_prompt = template.format(prompt=request.prompt)
+
          output = model(
              full_prompt,
              max_new_tokens=request.max_length,
-             temperature=0.7,
+             temperature=request.temperature,
              top_p=0.95,
              repeat_penalty=1.2
          )

-         # Clean up the output
-         poem = output.strip()
-
          generation_time = time.time() - start_time

          return PoetryResponse(
-             poem=poem,
-             generation_time=generation_time
+             poem=output.strip(),
+             generation_time=generation_time,
+             prompt=request.prompt,
+             style=request.style
          )

      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
+         logger.error(f"Generation error: {str(e)}")
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to generate poem: {str(e)}"
+         )
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
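One detail worth noting in the new generate_poem: styles outside the template map silently fall back to free verse because of the dict .get lookup, rather than returning an error (a sketch of that behavior, not part of the commit; "ode" is a hypothetical unsupported style):

    prompt_templates = {
        "haiku": "Write a haiku about {prompt}. Follow the 5-7-5 syllable pattern:\n\n",
        "free verse": "Write a free verse poem about {prompt}. Make it creative and meaningful:\n\n",
    }

    # "Haiku" matches after .lower(); "ode" does not, so it falls back to
    # the free verse template instead of raising a validation error.
    template = prompt_templates.get("ode".lower(), prompt_templates["free verse"])
    print(template.format(prompt="the sea"))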
requirements.txt CHANGED
@@ -18,5 +18,5 @@ accelerate==0.27.2
  python-jose==3.3.0 # for JWT handling if you add auth later
  gunicorn==21.2.0 # for production deployment
  python-dotenv==1.0.0 # for environment variables
-
+ ctransformers
  pyllamacpp==2.4.0