KaushikShresth12 committed on
Commit 8095d1b
1 Parent(s): b85b2f9

Update main.py

Files changed (1)
  1. main.py +39 -40
main.py CHANGED
@@ -1,24 +1,13 @@
-from fastapi import FastAPI
-from pydantic import BaseModel
-from huggingface_hub import InferenceClient
+
 import uvicorn

+from flask import Flask, request, jsonify
+from huggingface_hub import InferenceClient

-app = FastAPI()
+app = Flask(__name__)

 API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"

-
-class Item(BaseModel):
-    prompt: str
-    history: list
-    temperature: float = 0.1
-    max_new_tokens: int = 2
-    top_p: float = 0.15
-    repetition_penalty: float = 1.0
-    instructions: str = ""
-    api: str = ""
-
 def format_prompt(message, custom_instructions=None):
     prompt = ""
     if custom_instructions:
@@ -26,32 +15,42 @@ def format_prompt(message, custom_instructions=None):
     prompt += f"[INST] {message} [/INST]"
     return prompt

-def generate(item: Item):
-    try:
-        temperature = float(item.temperature)
-        if temperature < 1e-2:
-            temperature = 1e-2
-        top_p = float(item.top_p)
+def Mistral7B(prompt, instructions, api, temperature=0.1, max_new_tokens=2, top_p=0.95, repetition_penalty=1.0):
+    global API_URL
+    try:
+        temperature = float(temperature)
+        if temperature < 1e-2:
+            temperature = 1e-2
+        top_p = float(top_p)

-        generate_kwargs = dict(
+        generate_kwargs = dict(
             temperature=temperature,
-            max_new_tokens=item.max_new_tokens,
+            max_new_tokens=max_new_tokens,
             top_p=top_p,
-            repetition_penalty=item.repetition_penalty,
+            repetition_penalty=repetition_penalty,
             do_sample=True,
-            seed=42,
-        )
-        print(item)
-        custom_instructions=item.instructions
-        formatted_prompt = format_prompt(item.prompt, custom_instructions)
-        headers = {"Authorization": f"Bearer {item.api}"}
-        client = InferenceClient(API_URL, headers=headers)
-        response = client.text_generation(formatted_prompt, **generate_kwargs)
-        return {"response": response}
-    except Exception as e:
-        return {"error": str(e)}
-
-@app.post("/generate/")
-async def generate_text(item: Item):
-    return {"response": generate(item)}
-
+            seed=69,
+        )
+        custom_instructions = instructions
+        formatted_prompt = format_prompt(prompt, custom_instructions)
+
+        head = {"Authorization": f"Bearer {api}"}
+        client = InferenceClient(API_URL, headers=head)
+        response = client.text_generation(formatted_prompt, **generate_kwargs)
+        return response
+    except Exception as e:
+        return str(e)
+
+@app.route("/generate-text", methods=["POST"])
+def generate_text():
+    data = request.json
+    prompt = data.get("prompt")
+    instructions = data.get("instructions")
+    api_key = data.get("api_key")
+
+    if not prompt or not instructions or not api_key:
+        return jsonify({"error": "Missing required fields"}), 400
+
+    response = Mistral7B(prompt, instructions, api_key)
+
+    return jsonify({"response": response}), 200
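A minimal sketch of calling the new route from a client is shown below. Only the /generate-text path and the prompt, instructions, and api_key fields come from the diff above; the host, port, and payload values are illustrative assumptions, since the commit does not show how the Flask app is served.

import requests

# Assumed local address; the diff does not specify where the Flask app runs.
URL = "http://127.0.0.1:5000/generate-text"

payload = {
    "prompt": "Summarize what Mixtral-8x7B-Instruct is in one sentence.",  # illustrative prompt
    "instructions": "You are a concise assistant.",                        # illustrative custom instructions
    "api_key": "hf_...",                                                   # placeholder Hugging Face token
}

# POST the JSON body expected by generate_text(); prints the status code and response payload.
resp = requests.post(URL, json=payload, timeout=60)
print(resp.status_code, resp.json())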