PinkAlpaca committed on
Commit
b651d10
·
verified ·
1 Parent(s): a0cab94

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +5 -3
main.py CHANGED
@@ -15,9 +15,8 @@ if not gemini_api_key:
15
  # Make sure to use a model you have access to
16
  generator: TextGenerationPipeline = pipeline(
17
  "text-generation",
18
- model="llama-duo/gemma7b-summarize-gemini1.5flash-80k", # Replace if needed
19
- use_auth_token=gemini_api_key
20
- )
21
 
22
  # Data model for the request body
23
  class Item(BaseModel):
@@ -32,11 +31,14 @@ async def generate_text(item: Item):
32
  if not item.prompt:
33
  raise HTTPException(status_code=400, detail="`prompt` field is required")
34
 
 
 
35
  output = generator(
36
  item.prompt,
37
  temperature=item.temperature,
38
  max_length=item.max_new_tokens,
39
  )
 
40
 
41
  return {"generated_text": output[0]['generated_text']}
42
 
 
15
  # Make sure to use a model you have access to
16
  generator: TextGenerationPipeline = pipeline(
17
  "text-generation",
18
+ model="gemini-1.5-flash", # Replace if needed
19
+ ) # IMPORTANT: **DO NOT** set `use_auth_token` here
 
20
 
21
  # Data model for the request body
22
  class Item(BaseModel):
 
31
  if not item.prompt:
32
  raise HTTPException(status_code=400, detail="`prompt` field is required")
33
 
34
+ # Set API key in the headers BEFORE calling the pipeline
35
+ generator.model.config.use_auth_token = gemini_api_key # Set the API key here
36
  output = generator(
37
  item.prompt,
38
  temperature=item.temperature,
39
  max_length=item.max_new_tokens,
40
  )
41
+ generator.model.config.use_auth_token = None # Reset after use
42
 
43
  return {"generated_text": output[0]['generated_text']}
44