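"""Streamlit helper for generating chat completions with GPTQ-quantised models.

Loads one of the GPTQ model repos listed below with transformers, caches it in
module-level globals so it is not reloaded on every call, and runs generation
through a text-generation pipeline.
"""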
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import streamlit as st

def gptq_model_options():
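    """Hugging Face repo IDs of the GPTQ-quantised models the app can load."""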
    return [
        "TheBloke/Llama-2-7B-Chat-GPTQ", 
        "TheBloke/Llama-2-13B-chat-GPTQ",
        "TheBloke/meditron-7B-GPTQ",
        "TheBloke/meditron-70B-GPTQ",
    ]

# Module-level cache: the most recently loaded model object and its repo name,
# reused across calls so the model is not reloaded for every request.
loaded_model = None
loaded_model_name = ""

def get_llm_response(model_name_or_path, temperature, do_sample, top_p, top_k, max_new_tokens, repetition_penalty, formatted_prompt):
    """Generate a completion for an already formatted prompt, loading the requested GPTQ model on demand."""
    global loaded_model
    global loaded_model_name

    # Load the model only when a different one is requested; otherwise reuse the cached instance.
    if loaded_model_name != model_name_or_path:
        loaded_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                            device_map="auto",
                                                            trust_remote_code=False,
                                                            revision="main")
        loaded_model_name = model_name_or_path

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
            
    print("Formatted prompt:")
    print(formatted_prompt)
    
    st.session_state["llm_messages"].append(formatted_prompt)
    
    #print("\n\n*** Generate:")
    #input_ids = tokenizer(formatted_prompt, return_tensors='pt').input_ids.cuda()
    #output = model.generate(inputs=input_ids, temperature=temperature, do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=max_new_tokens)
    #print(tokenizer.decode(output[0], skip_special_tokens=True)) 

    print("*** Pipeline:")
    pipe = pipeline(
        "text-generation",
        model=loaded_model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        return_full_text=False
    )

    pipe_response = pipe(formatted_prompt)
    st.session_state["llm_messages"].append(pipe_response)
    print(pipe_response)
    return pipe_response[0]['generated_text']
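

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical): a bare-bones Streamlit page wired to the
# helpers above. The real app's UI and prompt formatting live elsewhere; the
# sampling values and the Llama-2 style "[INST]" template below are assumptions
# for illustration only. Run with: streamlit run <this_file>.py
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if "llm_messages" not in st.session_state:
        st.session_state["llm_messages"] = []

    model_name = st.selectbox("Model", gptq_model_options())
    user_prompt = st.text_area("Prompt", "What are the common symptoms of anaemia?")

    if st.button("Generate"):
        # Assumed Llama-2 chat formatting; adjust to the selected model's template.
        formatted_prompt = f"[INST] {user_prompt} [/INST]"
        answer = get_llm_response(
            model_name,
            temperature=0.7,          # assumed defaults, not taken from the original app
            do_sample=True,
            top_p=0.95,
            top_k=40,
            max_new_tokens=512,
            repetition_penalty=1.15,
            formatted_prompt=formatted_prompt,
        )
        st.write(answer)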