Prabhash committed on
Commit 35c95d7
1 Parent(s): 9d2536a

Update main.py

Files changed (1)
main.py +19 -25
main.py CHANGED
@@ -1,32 +1,26 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import os
-from fastapi import FastAPI
+from ctransformers import AutoModelForCausalLM
+from fastapi import FastAPI, Form
 from pydantic import BaseModel
 
-# Load Gemma-2B tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=os.environ["token"])
-model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=os.environ["token"])
+# Model loading
+llm = AutoModelForCausalLM.from_pretrained("zephyr-7b-beta.Q4_K_S.gguf",
+                                           model_type='mistral',
+                                           max_new_tokens=1096,
+                                           threads=3,
+                                           )
 
-# Define Pydantic object for input validation
-class GenerationRequest(BaseModel):
+# Pydantic object
+class validation(BaseModel):
     prompt: str
-    repeat_penalty: float = 1.0  # Default repeat penalty
 
-# Initialize FastAPI
+# Fast API
 app = FastAPI()
 
-# Define route for generating text
-@app.post("/generate_text")
-async def generate_text(request: GenerationRequest):
-    # Tokenize the input prompt
-    input_prompt = "<s>Below is an instruction that describes a task. Write a response that appropriately completes the request.</s>" + "\n" + "<s>" + request.prompt + "</s>"
-
-    # Encode the input prompt
-    input_ids = tokenizer.encode(input_prompt, return_tensors="pt")
-
-    # Generate text based on the input prompt
-    outputs = model.generate(input_ids, repeat_penalty=request.repeat_penalty)
-
-    # Decode the generated output and return
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return {"generated_text": generated_text}
+# Zephyr completion
+@app.post("/llm_on_cpu")
+async def stream(item: validation):
+    system_prompt = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
+    E_INST = "</s>"
+    user, assistant = "<|user|>", "<|assistant|>"
+    prompt = f"{system_prompt}{E_INST}\n{user}\n{item.prompt.strip()}{E_INST}\n{assistant}\n"
+    return llm(prompt)
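
For a quick smoke test of the new route, here is a minimal client sketch. It assumes the app from this commit is served locally (for example via "uvicorn main:app --port 8000"); the host, port, and example prompt are illustrative and not part of the commit.

    # Minimal client sketch (assumption: the app is running via
    # "uvicorn main:app --port 8000"; URL and prompt are illustrative).
    import requests

    resp = requests.post(
        "http://localhost:8000/llm_on_cpu",
        json={"prompt": "Summarise what a GGUF model file is in one sentence."},
    )
    resp.raise_for_status()
    # The route returns llm(prompt) directly, so the JSON body is a bare string.
    print(resp.json())

One design note: although the handler is named stream, it returns the full completion in a single response; token-by-token streaming would require something like FastAPI's StreamingResponse around the generator that ctransformers can expose.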