from typing import Any, Dict, List

from torch import bfloat16
from transformers import AutoTokenizer, AutoModelForCausalLM



class EndpointHandler:
    max_seq_length = 4096  # maximum context length supported by the model
    dtype = bfloat16  # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+
    load_in_4bit = True  # 4-bit quantization to reduce memory usage; set to False to disable
    max_new_tokens = 500  # upper bound on the length of each generated summary

    summary_prompt = """
        You will be given the introduction of a scientific research article in the medical field, delimited by '<article>' tags.
        Your mission is to produce a short summary of this introduction in a single paragraph of 200 characters maximum.

        ### Input:
        <article>{}</article>

        ### Response:
        {}
    """


    def __init__(self, path=""):
        # Load the model in 4-bit precision; device_map="auto" places the
        # quantized weights on the available GPU(s), so no explicit .to("cuda")
        # is needed (quantized models cannot be moved that way).
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=self.dtype,
            load_in_4bit=self.load_in_4bit,
            device_map="auto",
        )

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.padding_side = "left"  # left-pad so generation continues from the prompt end
        self.tokenizer.pad_token = self.tokenizer.eos_token  # model ships without a pad token
          

    def _secure_inputs(self, data: Dict[str, Any]):
        error_payload = [{
            "error": "inputs should be shaped like {'temperature': float, 'inputs': <string or list of strings (abstracts)>}"
        }]

        # Always return a 3-tuple so the caller can unpack it uniformly.
        if not isinstance(data, dict) or "inputs" not in data:
            return error_payload, None, False

        temperature = data.get("temperature", 0.01)
        inputs = data["inputs"]

        # Accept a single abstract as well as a batch.
        if isinstance(inputs, str):
            inputs = [inputs]

        return inputs, temperature, True
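
    # Example of a well-formed request payload (illustrative values only):
    #   {"temperature": 0.2, "inputs": ["First abstract ...", "Second abstract ..."]}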

    
    def _format_inputs(self, inputs: List[str]):
        # Fill the prompt template and record each prompt's character length so the
        # echoed prompt can be stripped from the decoded output later.
        prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
        prompts_lengths = [len(prompt) for prompt in prompts]
        return prompts, prompts_lengths

    def _generate_outputs(self, inputs, temperature):
        tokenized = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
        # Sample only when a positive temperature is requested; otherwise decode greedily.
        generate_kwargs = dict(max_new_tokens=self.max_new_tokens, use_cache=True)
        if temperature and temperature > 0:
            generate_kwargs.update(do_sample=True, temperature=temperature)
        outputs = self.model.generate(**tokenized, **generate_kwargs)
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def _format_outputs(self, outputs: List[str], inputs_lengths: List[int]):
        # Drop the echoed prompt from each decoded sequence, keeping only the summary.
        # Note: this assumes the decoded text reproduces the prompt verbatim.
        return [
            output_str[input_len:].strip()
            for output_str, input_len in zip(outputs, inputs_lengths)
        ]
        
        
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs, temperature, is_secure = self._secure_inputs(data)

        if not is_secure:
            return inputs  # validation failed; return the error payload as-is

        prompts, prompts_lengths = self._format_inputs(inputs)
        outputs = self._generate_outputs(prompts, temperature)
        summaries = self._format_outputs(outputs, prompts_lengths)

        return [{"summary": summary} for summary in summaries]
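

# Minimal local usage sketch (illustrative only: "path/to/checkpoint" is a
# placeholder, and a CUDA-capable GPU with bitsandbytes/accelerate installed
# is assumed):
#
#     handler = EndpointHandler(path="path/to/checkpoint")
#     result = handler({"temperature": 0.2, "inputs": "Recent advances in ..."})
#     print(result[0]["summary"])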