HFHAB committed on
Commit eaea91c
1 Parent(s): 73159e6

Update main.py

Files changed (1)
  1. main.py +20 -5
main.py CHANGED
@@ -7,7 +7,19 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

 app = FastAPI()

-client = InferenceClient("HFHAB/FinetunedMistralModel")
+model_id = "mistralai/Mistral-7B-v0.1"
+model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+prompt = "<s>[INST] Write a tweet on future of AI [/INST]"
+inputs = tokenizer(prompt, return_tensors="pt").to(0)
+
+out = model.generate(**inputs, max_new_tokens=250, temperature=0.6, top_p=0.95, top_k=40)
+
+print(tokenizer.decode(out[0], skip_special_tokens=True))
+
+
+#client = InferenceClient("HFHAB/FinetunedMistralModel")

 class Item(BaseModel):
     prompt: str
@@ -46,11 +58,14 @@ def generate(item: Item):
     )

     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
+    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(0)
+    out = model.generate(**inputs, max_new_tokens=250, temperature=0.6, top_p=0.95, top_k=40)
+    output = tokenizer.decode(out[0], skip_special_tokens=True)
+    #stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    #output = ""

-    for response in stream:
-        output += response.token.text
+    #for response in stream:
+    #    output += response.token.text
     return output

 @app.post("/generate/")