File size: 1,193 Bytes
99e744f
 
 
 
 
 
 
 
 
00742e9
 
 
99e744f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
from logging import getLogger

import streamlit as st
import torch
from dotenv import load_dotenv
from huggingface_hub import login
from langchain_openai import OpenAI
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Pipeline, pipeline

# Pull variables from a local .env file (if one exists) into os.environ.
load_dotenv()
logger = getLogger(__name__)

# Hugging Face Hub access token. On Streamlit Cloud it could instead be read
# from st.secrets["HF_TOKEN"].
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    # huggingface_hub.login() with no token falls back to an interactive prompt,
    # which would block a headless / Streamlit process at import time. Skip the
    # login instead; public models can still be downloaded without it.
    logger.warning("HF_TOKEN is not set; skipping Hugging Face Hub login.")

# Local pipelines run on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def get_local_model(model_name_or_path: str) -> "Pipeline":
    """Load a seq2seq checkpoint and wrap it in a summarization pipeline.

    Args:
        model_name_or_path: Hugging Face Hub model id or local directory; it is
            passed to ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        A transformers ``"summarization"`` pipeline placed on the module-level
        ``device`` ("cuda" when PyTorch sees a GPU, otherwise "cpu").
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
    summarizer = pipeline(
        task="summarization",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Summarization pipeline created and loaded to %s", device)
    return summarizer

def get_endpoint(api_key: str):
    """Return a LangChain ``OpenAI`` LLM client built with *api_key*.

    The key is forwarded unchanged as ``openai_api_key``. ``get_model`` may pass
    ``None`` here. NOTE(review): the client presumably then falls back to the
    OPENAI_API_KEY environment variable; confirm.
    """
    return OpenAI(openai_api_key=api_key)

def get_model(model_type, model_name_or_path, api_key=None):
    """Return the model backend selected by *model_type*.

    ``"openai"`` yields a LangChain OpenAI client built from *api_key*
    (*model_name_or_path* is ignored). Any other value loads
    *model_name_or_path* as a local Hugging Face summarization pipeline
    (*api_key* is ignored).
    """
    if model_type != "openai":
        return get_local_model(model_name_or_path)
    return get_endpoint(api_key)