abhisheksan commited on
Commit
11be554
·
1 Parent(s): 2107ef1

Refactor poem generation logic; implement lazy loading for model and update request/response models

Browse files
Files changed (2) hide show
  1. main.py +46 -39
  2. models/gpt4all-lora-quantized-ggml.bin +0 -3
main.py CHANGED
@@ -1,9 +1,8 @@
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
- from pyllamacpp.model import Model
5
- import os
6
- from typing import Optional
7
 
8
  # Initialize FastAPI app
9
  app = FastAPI(title="Poetry Generator")
@@ -17,58 +16,66 @@ app.add_middleware(
17
  allow_headers=["*"],
18
  )
19
 
20
- # Model path - adjust this to your model location
21
- MODEL_PATH = "models/gpt4all-lora-quantized-ggml.bin"
22
-
23
- # Initialize the model at startup
24
  model = None
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class PoetryRequest(BaseModel):
27
- theme: str
28
- style: Optional[str] = "free verse"
29
- length: Optional[int] = 100
30
 
31
  class PoetryResponse(BaseModel):
32
  poem: str
 
33
 
34
  @app.on_event("startup")
35
  async def startup_event():
36
- global model
37
- if not os.path.exists(MODEL_PATH):
38
- raise Exception(f"Model file not found at {MODEL_PATH}")
39
- try:
40
- model = Model(
41
- model_path=MODEL_PATH,
42
- )
43
- except Exception as e:
44
- raise Exception(f"Failed to load model: {str(e)}")
45
 
46
  @app.post("/generate_poem", response_model=PoetryResponse)
47
  async def generate_poem(request: PoetryRequest):
48
- if model is None:
49
- raise HTTPException(status_code=500, detail="Model not initialized")
50
-
51
  try:
52
- # Craft the prompt
53
- prompt = f"""Write a {request.style} poem about {request.theme}.
54
- Keep it approximately {request.length} characters long.
55
- Make it creative and meaningful.\n\nPoem:"""
56
 
 
 
 
 
 
 
57
  # Generate the poem
58
- generated_text = model.generate(
59
- prompt,
60
- n_predict=request.length,
61
- temp=0.7,
62
- top_k=40,
63
- top_p=0.9,
64
- repeat_penalty=1.1,
65
- n_batch=8 # Reduced batch size for lower memory usage
66
  )
 
 
 
67
 
68
- # Clean up the generated text
69
- poem = generated_text.replace(prompt, "").strip()
70
 
71
- return PoetryResponse(poem=poem)
72
-
 
 
 
73
  except Exception as e:
74
- raise HTTPException(status_code=500, detail=f"Error generating poem: {str(e)}")
 
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
+ from transformers import AutoModelForCausalLM
5
+ import time
 
6
 
7
  # Initialize FastAPI app
8
  app = FastAPI(title="Poetry Generator")
 
16
  allow_headers=["*"],
17
  )
18
 
19
# Module-level handle for the language model. It starts as None and is
# populated by load_model() (lazy loading), so importing this module does
# not pay the model-load cost.
model = None
22
def load_model():
    """Load the generation model into the module-level ``model`` global.

    Idempotent: the model is loaded on the first call only; later calls
    are no-ops. Invoked from the FastAPI startup hook.
    """
    global model
    if model is None:
        # Load a quantized GGUF model.
        # You can download models from huggingface.co
        # Example: GPT2 or Llama-2-7b-chat.Q4_K_M.gguf
        # NOTE(review): these keyword arguments (model_file, model_type,
        # context_length, gpu_layers) match ctransformers'
        # AutoModelForCausalLM.from_pretrained, not the transformers class
        # imported at the top of this file — confirm the intended library.
        model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/Llama-2-7B-Chat-GGUF",
            model_file="llama-2-7b-chat.q4_K_M.gguf",
            model_type="llama",
            max_new_tokens=256,
            context_length=512,
            gpu_layers=0  # CPU only
        )
36
+
37
class PoetryRequest(BaseModel):
    """Request body for POST /generate_poem."""

    # Theme/subject of the poem (interpolated into the model prompt).
    prompt: str
    # Poetic style, also interpolated into the prompt.
    style: str = "free verse"
    # Upper bound on generated tokens (passed as max_new_tokens).
    max_length: int = 200
41
 
42
class PoetryResponse(BaseModel):
    """Response body for POST /generate_poem."""

    # The generated poem text, stripped of surrounding whitespace.
    poem: str
    # Seconds spent generating, measured inside the endpoint.
    generation_time: float
45
 
46
@app.on_event("startup")
async def startup_event():
    """FastAPI startup hook: trigger the lazy model load eagerly so the
    first request does not pay the (multi-second) load cost."""
    load_model()
 
 
 
 
 
 
 
 
49
 
50
@app.post("/generate_poem", response_model=PoetryResponse)
async def generate_poem(request: PoetryRequest):
    """Generate a poem for the requested theme and style.

    Returns the generated text plus the time (seconds) spent generating.
    Raises HTTP 500 if the model is not loaded or generation fails.
    """
    # Guard restored from the previous revision: without it, a failed or
    # skipped startup load surfaces as an opaque
    # "'NoneType' object is not callable" inside the generic handler below.
    if model is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    try:
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed (or go negative) by system clock adjustments mid-request.
        start_time = time.perf_counter()

        # Construct the prompt
        full_prompt = f"""Write a {request.style} poem about {request.prompt}.
Make it creative and meaningful. The poem should be:

"""

        # Generate the poem. NOTE(review): calling the model object directly
        # with these kwargs matches the ctransformers API — confirm the
        # AutoModelForCausalLM import at the top of the file is from the
        # intended library.
        output = model(
            full_prompt,
            max_new_tokens=request.max_length,
            temperature=0.7,
            top_p=0.95,
            repeat_penalty=1.2
        )

        # Clean up the output
        poem = output.strip()

        generation_time = time.perf_counter() - start_time

        return PoetryResponse(
            poem=poem,
            generation_time=generation_time
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
models/gpt4all-lora-quantized-ggml.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d9af98b0350fc8af7211097e816ffbb8bae9a18f8aea8c50ff94a99bd6cb2c7c
3
- size 4212860154