from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn


app = FastAPI()

# roberta-large-mnli is an NLI model; the Inference API uses it to score
# each candidate label against the input text (zero-shot classification).
client = InferenceClient("FacebookAI/roberta-large-mnli")


class Item(BaseModel):
    prompt: str


def generate(item: Item):
    text = item.prompt
    labels = ["Requirement", "Information"]
    # With recent huggingface_hub versions this returns one element per
    # candidate label, ordered by score (highest first), each exposing
    # .label and .score attributes.
    result = client.zero_shot_classification(text, labels)
    print("Predicted labels:")
    for element in result:
        print(element.label, element.score)
    # Return the best-scoring label for the prompt.
    return result[0].label


@app.post("/generate/")
async def generate_text(item: Item):
    return {"response": generate(item)}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)