abhisheksan committed on
Commit
cee4b22
Parent: 2c9446c

Add initial project structure with FastAPI and poetry generation service

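For orientation, the files added in this commit form the following layout (the committed __pycache__ artifacts are omitted here; compiled caches are usually excluded via .gitignore):

main.py
requirements.txt
app/
├── api/
│   └── endpoints/
│       └── poetry.py
├── core/
│   ├── config.py
│   └── models.py
├── services/
│   └── poetry_generation.py
└── utils/
    └── text_processing.py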
__pycache__/main.cpython-312.pyc ADDED
Binary file (1.88 kB)
 
app/api/endpoints/__pycache__/poetry.cpython-312.pyc ADDED
Binary file (2.59 kB)
 
app/api/endpoints/poetry.py ADDED
@@ -0,0 +1,51 @@
+from fastapi import APIRouter, HTTPException
+from pydantic import BaseModel, Field
+from typing import Optional
+from ...services.poetry_generation import PoetryGenerationService
+
+router = APIRouter()
+
+# Shared service instance, created lazily so the model is loaded once,
+# on the first request, rather than re-initialized on every call.
+_service: Optional[PoetryGenerationService] = None
+
+def get_service() -> PoetryGenerationService:
+    global _service
+    if _service is None:
+        _service = PoetryGenerationService()
+    return _service
+
+class PoemRequest(BaseModel):
+    prompt: str = Field(..., description="The prompt for poem generation")
+    temperature: Optional[float] = Field(0.7, ge=0.1, le=2.0)
+    top_p: Optional[float] = Field(0.9, ge=0.1, le=1.0)
+    top_k: Optional[int] = Field(50, ge=1, le=100)
+    max_length: Optional[int] = Field(200, ge=50, le=500)
+    repetition_penalty: Optional[float] = Field(1.1, ge=1.0, le=2.0)
+
+class PoemResponse(BaseModel):
+    poem: str
+    parameters_used: dict
+
+@router.post("/generate", response_model=PoemResponse)
+async def generate_poem(request: PoemRequest):
+    try:
+        poem = await get_service().generate_poem(
+            prompt=request.prompt,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            top_k=request.top_k,
+            max_length=request.max_length,
+            repetition_penalty=request.repetition_penalty
+        )
+        return PoemResponse(
+            poem=poem,
+            parameters_used={
+                "temperature": request.temperature,
+                "top_p": request.top_p,
+                "top_k": request.top_k,
+                "max_length": request.max_length,
+                "repetition_penalty": request.repetition_penalty
+            }
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
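Once main.py mounts this router under /api/v1/poetry, the endpoint can be exercised with the requests package already pinned in requirements.txt. The sketch below assumes a server running locally on port 8000; the payload values are illustrative.

import requests

resp = requests.post(
    "http://localhost:8000/api/v1/poetry/generate",
    json={
        "prompt": "Write a haiku about autumn rain",
        "temperature": 0.8,
        "max_length": 120,
    },
    timeout=300,  # the first call is slow while the model loads lazily
)
resp.raise_for_status()
body = resp.json()
print(body["poem"])
print(body["parameters_used"])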
app/core/__pycache__/config.cpython-312.pyc ADDED
Binary file (155 Bytes)
 
app/core/config.py ADDED
Empty file
app/core/models.py ADDED
Empty file
app/services/__pycache__/poetry_generation.cpython-312.pyc ADDED
Binary file (3.15 kB)
 
app/services/poetry_generation.py ADDED
@@ -0,0 +1,73 @@
+from typing import Optional
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+class PoetryGenerationService:
+    def __init__(self):
+        model_name = "meta-llama/Llama-3.2-3B-Instruct"  # Adjust model name as needed
+
+        # Initialize tokenizer and model
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        # Llama tokenizers ship without a pad token; reuse EOS so padding works
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,  # Use float16 for efficiency
+            device_map="auto"  # Automatically handle device placement
+        )
+
+        # Set model to evaluation mode
+        self.model.eval()
+
+    async def generate_poem(
+        self,
+        prompt: str,
+        temperature: Optional[float] = 0.7,
+        top_p: Optional[float] = 0.9,
+        top_k: Optional[int] = 50,
+        max_length: Optional[int] = 200,
+        repetition_penalty: Optional[float] = 1.1
+    ) -> str:
+        try:
+            # Tokenize the input prompt
+            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
+
+            # Move input tensors to the same device as the model
+            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+
+            # Generate text with the specified parameters
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    inputs["input_ids"],
+                    attention_mask=inputs["attention_mask"],
+                    do_sample=True,
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    max_length=max_length,  # budget includes the prompt tokens
+                    repetition_penalty=repetition_penalty,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    eos_token_id=self.tokenizer.eos_token_id,
+                )
+
+            # Decode the generated text (the prompt is included in the output)
+            generated_text = self.tokenizer.decode(
+                outputs[0],
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True
+            )
+
+            return generated_text
+
+        except Exception as e:
+            raise RuntimeError(f"Error generating poem: {e}") from e
+
+    def __del__(self):
+        # Release model and tokenizer memory on teardown
+        try:
+            del self.model
+            del self.tokenizer
+            torch.cuda.empty_cache()  # If using GPU
+        except Exception:
+            pass
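Because generate_poem is an async method, calling it outside the FastAPI request cycle needs an event loop. A minimal standalone sketch, assuming the gated meta-llama weights are accessible from your Hugging Face account:

import asyncio

from app.services.poetry_generation import PoetryGenerationService

async def main():
    service = PoetryGenerationService()  # loads tokenizer and the ~3B-parameter model
    poem = await service.generate_poem(
        prompt="Write a short poem about the sea",
        temperature=0.7,
        max_length=150,
    )
    print(poem)

asyncio.run(main())

Note that the generation itself is synchronous PyTorch code; declaring the method async does not by itself keep it from blocking the event loop.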
app/utils/__pycache__/text_processing.cpython-312.pyc ADDED
Binary file (878 Bytes)
 
app/utils/text_processing.py ADDED
@@ -0,0 +1,12 @@
+import re
+
+def preprocess_input(mood, theme, style):
+    """Build a generation prompt from the user's mood, theme, and style."""
+    input_text = f"Write a poem with a {mood} mood, about the theme of {theme}, in the style of {style}."
+    input_text = preprocess_text(input_text)
+    return input_text
+
+def preprocess_text(text):
+    """Normalize input text: collapse whitespace and lowercase."""
+    text = re.sub(r'\s+', ' ', text.strip().lower())
+    return text
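For illustration, preprocess_input interpolates the three fields into a prompt template and then normalizes whitespace and case:

from app.utils.text_processing import preprocess_input

prompt = preprocess_input("melancholic", "  Lost Time ", "Emily Dickinson")
# -> "write a poem with a melancholic mood, about the theme of lost time, in the style of emily dickinson."

Neither helper is wired into the API yet; the /generate endpoint passes the raw prompt straight to the model.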
main.py ADDED
@@ -0,0 +1,30 @@
+from fastapi import FastAPI
+from fastapi.responses import PlainTextResponse
+from app.api.endpoints.poetry import router as poetry_router
+
+app = FastAPI()
+app.include_router(poetry_router, prefix="/api/v1/poetry")
+
+@app.get("/healthz", response_class=PlainTextResponse)
+async def lifecheck() -> str:
+    return "OK"
+
+if __name__ == "__main__":
+    import os
+    import logging
+
+    import uvicorn
+    from starlette.staticfiles import StaticFiles
+
+    logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger(__name__)
+
+    port = int(os.getenv("PORT", "8000"))
+
+    # Serve ./static only if it exists; the directory is not part of this commit
+    if os.path.isdir("static"):
+        app.mount("/static", StaticFiles(directory="static"), name="static")
+
+    logger.info(f"Starting FastAPI server on port {port}")
+    # FastAPI has no app.run(); serve the ASGI app with uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=port)
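Since the poetry service is constructed lazily, importing the app does not load the model, so the health route can be smoke-tested cheaply with FastAPI's TestClient (which requires httpx, not pinned in requirements.txt). A minimal sketch:

from fastapi.testclient import TestClient

from main import app

client = TestClient(app)

response = client.get("/healthz")
assert response.status_code == 200
assert response.text == "OK"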
requirements.txt ADDED
@@ -0,0 +1,19 @@
+# Core dependencies
+gradio==4.19.2
+fastapi==0.109.2
+uvicorn==0.27.1
+pydantic==2.6.1
+
+# For API requests and handling
+requests==2.31.0
+python-multipart==0.0.9
+
+# For model handling
+torch==2.2.0
+transformers==4.37.2
+accelerate==0.27.2
+
+# Optional extras
+python-jose==3.3.0  # for JWT handling if you add auth later
+gunicorn==21.2.0  # for production deployment
+python-dotenv==1.0.0  # for environment variables