# Scraped file header: size 1,053 bytes; revisions 0809507, 3b7cf58.
from typing import Any, List, Mapping, Optional

from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hugging Face checkpoint used for both the tokenizer and the model.
model_name = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# AutoModelForCausalLM reads the architecture config from the checkpoint
# itself; no explicit config is passed (a config for an unrelated
# architecture such as T5 would be wrong for this BLOOM checkpoint).
model = AutoModelForCausalLM.from_pretrained(model_name)

# Sampling-based text-generation pipeline; CustomLLM below uses it as its
# default `pipeline`. Pass device=<GPU device number> (e.g. device=0) to run
# on a GPU.
pl = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    do_sample=True,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
)

class CustomLLM(LLM):
    """LangChain ``LLM`` backed by the module-level Hugging Face pipeline ``pl``.

    Each call runs the text-generation pipeline on the prompt and returns
    only the newly generated continuation.
    """

    # Annotated so pydantic (which LangChain's ``LLM`` is built on) treats it
    # as a field; pydantic v2 rejects un-annotated class attributes.
    pipeline: Any = pl

    # Upper bound on the number of tokens generated per call.
    max_new_tokens: int = 525

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a continuation of *prompt*.

        Args:
            prompt: Text to continue.
            stop: Optional stop sequences. The output is cut just before the
                earliest occurrence of any of them; empty strings are ignored.

        Returns:
            Only the newly generated text, without the prompt.
        """
        # return_full_text=False makes the pipeline return only the added
        # text, rather than slicing by len(prompt): the decoded output is not
        # guaranteed to reproduce the prompt character-for-character.
        outputs = self.pipeline(
            prompt,
            max_new_tokens=self.max_new_tokens,
            return_full_text=False,
        )
        text = outputs[0]["generated_text"]

        # LangChain hands stop sequences to _call and expects generation to
        # end at the first one found.
        if stop:
            hits = [idx for idx in (text.find(s) for s in stop if s) if idx != -1]
            if hits:
                text = text[: min(hits)]
        return text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that identify this LLM to LangChain."""
        # The checkpoint name lives at module level; the class defines no
        # `model_name` attribute of its own.
        return {"name_of_model": model_name}

    @property
    def _llm_type(self) -> str:
        """Short type tag LangChain uses for this LLM."""
        return "custom"