from fastapi import FastAPI, HTTPException
from typing import Any
from pydantic import BaseModel
from os import getenv
from huggingface_hub import InferenceClient
import random
from json_repair import repair_json
import nltk

app = FastAPI()

# Download the Punkt tokenizer data and load the English sentence tokenizer once at startup.
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

HF_TOKEN = getenv("HF_TOKEN")  # Hugging Face API token, read from the environment
MODEL = "mistralai/Mistral-7B-Instruct-v0.2"  # default model
FALLBACK_MODELS = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1",
]

class InputData(BaseModel):
    """Request body for /generate-response/."""
    model: str
    system_prompt_template: str
    prompt_template: str
    system_prompt: str
    user_input: str
    json_prompt: str
    history: str = ""  # running conversation history, appended to on each call

@app.post("/generate-response/")
async def generate_response(data: InputData) -> Any:
    # Split the user input into sentences and package it in the format the prompt expects.
    sentences = tokenizer.tokenize(data.user_input)
    data_dict = {
        '###New response###': sentences,
        '###Sentence count###': len(sentences),
    }

    data.history += data.prompt_template.replace("{Prompt}", str(data_dict))

    # Build the full prompt: system prompt, JSON-format instructions, then the history.
    inputs = (
        data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt) +
        data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt) +
        data.history)

    # Use the same seed for every attempt so retries across models are comparable.
    seed = random.randint(0, 2**32 - 1)

    # Try the requested model first, then the fallbacks in order.
    models_to_try = [data.model] + FALLBACK_MODELS

    for model in models_to_try:
        try:
            # Create the client for the model currently being tried, so fallbacks
            # actually query a different model instead of reusing the first one.
            client = InferenceClient(model=model, token=HF_TOKEN)

            response = client.text_generation(inputs,
                                              temperature=1.0,
                                              max_new_tokens=1000,
                                              seed=seed)

            # repair_json returns a parsed object when it can fix the JSON,
            # or a string when it cannot; treat a string as a failed attempt.
            repaired_response = repair_json(str(response), return_objects=True)
            if isinstance(repaired_response, str):
                print(f"Model {model} returned a non-JSON response, trying next model")
                continue

            # Strip the "###" markers from the keys.
            cleaned_response = {
                key.replace("###", ""): value
                for key, value in repaired_response.items()
            }

            # Keep at most three sentences, trimming each entry to its first sentence.
            # (Slicing avoids mutating the list while iterating over it.)
            trimmed = []
            for text in cleaned_response.get("New response", [])[:3]:
                text_sentences = tokenizer.tokenize(str(text))
                trimmed.append(text_sentences[0] if text_sentences else text)
            cleaned_response["New response"] = trimmed

            # Clamp the reported sentence count to the number of sentences kept.
            if cleaned_response.get("Sentence count"):
                cleaned_response["Sentence count"] = min(cleaned_response["Sentence count"], 3)
            else:
                cleaned_response["Sentence count"] = len(cleaned_response["New response"])

            data.history += str(cleaned_response)
            return cleaned_response

        except Exception as e:
            print(f"Model {model} failed with error: {e}")

    raise HTTPException(status_code=500, detail="All models failed to generate response")
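
# Minimal local smoke test (not part of the original file): a sketch of how the
# endpoint can be exercised, assuming the dependencies are installed, HF_TOKEN is
# set in the environment, and outbound requests to the Hugging Face Inference API
# are possible. The prompt templates below are illustrative placeholders only.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    test_client = TestClient(app)
    payload = {
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "system_prompt_template": "<s>[INST] {SystemPrompt} [/INST]",
        "prompt_template": "[INST] {Prompt} [/INST]",
        "system_prompt": "You are a helpful assistant.",
        "user_input": "Hello there. How are you today?",
        "json_prompt": "Respond only with JSON containing ###New response### and ###Sentence count###.",
        "history": "",
    }
    result = test_client.post("/generate-response/", json=payload)
    print(result.status_code, result.json())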