novamysticX committed on
Commit
fca0532
1 Parent(s): 45ca055

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -10,17 +10,17 @@ logger = logging.getLogger(__name__)
10
 
11
  app = FastAPI(title="SQL Coder API")
12
 
13
- # Ensure cache directory exists
14
- cache_dir = os.getenv('TRANSFORMERS_CACHE', '/home/user/.cache/huggingface')
15
- os.makedirs(cache_dir, exist_ok=True)
16
 
17
  # Initialize pipeline
18
  try:
19
- pipe = pipeline("text-generation",
20
- model="defog/llama-3-sqlcoder-8b",
21
- device_map="auto",
22
- torch_dtype="auto",
23
- cache_dir=cache_dir)
 
24
  logger.info("Pipeline initialized successfully")
25
  except Exception as e:
26
  logger.error(f"Error initializing pipeline: {str(e)}")
@@ -32,7 +32,7 @@ class ChatMessage(BaseModel):
32
 
33
  class QueryRequest(BaseModel):
34
  messages: list[ChatMessage]
35
- max_length: int = 1024
36
  temperature: float = 0.7
37
 
38
  class QueryResponse(BaseModel):
@@ -47,10 +47,11 @@ async def generate(request: QueryRequest):
47
  # Generate response using pipeline
48
  response = pipe(
49
  formatted_prompt,
50
- max_length=request.max_length,
51
  temperature=request.temperature,
52
  do_sample=True,
53
- num_return_sequences=1
 
54
  )
55
 
56
  # Extract generated text
 
10
 
11
  app = FastAPI(title="SQL Coder API")
12
 
13
+ # Set environment variable for cache directory
14
+ os.environ['TRANSFORMERS_CACHE'] = '/home/user/.cache/huggingface'
 
15
 
16
  # Initialize pipeline
17
  try:
18
+ pipe = pipeline(
19
+ "text-generation",
20
+ model="defog/llama-3-sqlcoder-8b",
21
+ device_map="auto",
22
+ model_kwargs={"torch_dtype": "auto"}
23
+ )
24
  logger.info("Pipeline initialized successfully")
25
  except Exception as e:
26
  logger.error(f"Error initializing pipeline: {str(e)}")
 
32
 
33
  class QueryRequest(BaseModel):
34
  messages: list[ChatMessage]
35
+ max_new_tokens: int = 1024
36
  temperature: float = 0.7
37
 
38
  class QueryResponse(BaseModel):
 
47
  # Generate response using pipeline
48
  response = pipe(
49
  formatted_prompt,
50
+ max_new_tokens=request.max_new_tokens,
51
  temperature=request.temperature,
52
  do_sample=True,
53
+ num_return_sequences=1,
54
+ pad_token_id=pipe.tokenizer.eos_token_id
55
  )
56
 
57
  # Extract generated text