from fastapi import FastAPI, HTTPException
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from os import getenv
from huggingface_hub import InferenceClient
import random
import nltk
from word_forms.word_forms import get_word_forms

app = FastAPI()

# Fetch the Punkt sentence-tokenizer data once at startup.
nltk.download('punkt')

tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
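
# Quick illustration of Punkt sentence segmentation:
#   tokenizer.tokenize("Hi there. How are you?")
#   -> ["Hi there.", "How are you?"]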

# Hugging Face API token, read from the environment.
HF_TOKEN = getenv("HF_TOKEN")

class InputData(BaseModel):
    model: str
    system_prompt_template: str
    prompt_template: str
    end_token: str
    system_prompts: List[str]
    user_inputs: List[str]
    history: str = ""
    segment: bool = False
    max_sentences: Optional[int] = None
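
# Example /generate-response request body (illustrative values only; the model
# id and template markers are assumptions, not defaults of this API):
# {
#   "model": "mistralai/Mistral-7B-Instruct-v0.2",
#   "system_prompt_template": "<|system|>{SystemPrompt}</s>",
#   "prompt_template": "<|user|>{Prompt}</s><|assistant|>",
#   "end_token": "</s>",
#   "system_prompts": ["You are a concise storyteller."],
#   "user_inputs": ["Tell me a short story about a lighthouse."],
#   "segment": true,
#   "max_sentences": 3
# }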

class WordCheckData(BaseModel):
    string: str
    word: str
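
# Example /check-word request body (illustrative):
#   {"string": "She runs every morning.", "word": "run"}
# expected result: {"found": true}, since "runs" is an inflected form of "run".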

@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
    # A non-zero sentence cap implies segmentation; a cap of 0 means "record
    # the prompts in history but skip generation entirely".
    if data.max_sentences is not None and data.max_sentences != 0:
        data.segment = True
    elif data.max_sentences == 0:
        for user_input in data.user_inputs:
            data.history += data.prompt_template.replace("{Prompt}", user_input) + "\n"
        return {
            "response": [],
            "sentence_count": None,
            "history": data.history + data.end_token
        }

    # Append each user input to the running history, splitting multi-sentence
    # inputs onto separate lines first when segmentation is enabled.
    if data.segment:
        for user_input in data.user_inputs:
            user_sentences = tokenizer.tokenize(user_input)
            user_input_str = "\n".join(user_sentences)
            data.history += data.prompt_template.replace("{Prompt}", user_input_str) + "\n"
    else:
        for user_input in data.user_inputs:
            data.history += data.prompt_template.replace("{Prompt}", user_input) + "\n"

    # Prepend the formatted system prompts to the accumulated history.
    inputs = ""
    for system_prompt in data.system_prompts:
        inputs += data.system_prompt_template.replace("{SystemPrompt}", system_prompt) + "\n"
    inputs += data.history

    # Randomize the seed so identical requests still produce varied output.
    seed = random.randint(0, 2**32 - 1)

    try:
        # Query the Hugging Face Inference API for a completion.
        client = InferenceClient(model=data.model, token=HF_TOKEN)
        response = client.text_generation(
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed
        )

        response_str = str(response)

        if data.segment:
            # Split the generation into sentences and apply the optional cap.
            ai_sentences = tokenizer.tokenize(response_str)
            if data.max_sentences is not None:
                ai_sentences = ai_sentences[:data.max_sentences]
            responses = ai_sentences
            sentence_count = len(ai_sentences)
        else:
            responses = [response_str]
            sentence_count = None

        data.history += response_str + "\n"

        return {
            "response": responses,
            "sentence_count": sentence_count,
            "history": data.history + data.end_token
        }

    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(status_code=500, detail=f"Model {data.model} failed to generate response")

@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
    input_string = data.string.lower()
    word = data.word.lower()

    # get_word_forms returns form sets keyed by part of speech
    # (noun, verb, adjective, adverb); flatten them into one lookup set.
    forms = get_word_forms(word)

    all_forms = set()
    for words in forms.values():
        all_forms.update(words)

    # The word "matches" if any whitespace-separated token in the input,
    # reduced to its alphabetic characters, equals one of the known forms.
    found = any(
        ''.join(filter(str.isalpha, input_word)) in all_forms
        for input_word in input_string.split()
    )

    return {"found": found}
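
# Local development entry point; a minimal sketch assuming uvicorn is
# installed and that port 7860 is free (both assumptions, not requirements
# of this app).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)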