import os

from transformers import LEDForConditionalGeneration, LEDTokenizer, Pipeline, pipeline
from langchain_openai import OpenAI
# from huggingface_hub import login  # only needed for gated Hugging Face models
from dotenv import load_dotenv
from logging import getLogger
# import streamlit as st  # only needed when deploying on Streamlit
import torch
load_dotenv()
hf_token = os.environ.get("HF_TOKEN")
# On Streamlit Cloud, read the token from secrets and log in instead:
# hf_token = st.secrets["HF_TOKEN"]
# login(token=hf_token)

logger = getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
def get_local_model(model_name_or_path: str) -> Pipeline:
    """Load a local LED checkpoint and wrap it in a summarization pipeline."""
    logger.info(f"Model is running on {device}")
    # LED-specific classes replace AutoTokenizer / AutoModelForSeq2SeqLM to
    # support the Longformer Encoder-Decoder's long-document inputs.
    tokenizer = LEDTokenizer.from_pretrained(
        model_name_or_path,
        token=hf_token,
    )
    model = LEDForConditionalGeneration.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float32,
        token=hf_token,
    )
    pipe = pipeline(
        task="summarization",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    logger.info(f"Summarization pipeline created and loaded to {device}")
    return pipe
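
# A minimal usage sketch for the pipeline above. The checkpoint name
# "allenai/led-base-16384" is an assumption (any LED checkpoint works), and
# the generation kwargs are illustrative defaults, not project settings:
#
#     summarizer = get_local_model("allenai/led-base-16384")
#     result = summarizer(long_document, max_length=256, min_length=64)
#     print(result[0]["summary_text"])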
def get_endpoint(api_key: str):
    """Create a LangChain OpenAI LLM client from an API key."""
    llm = OpenAI(openai_api_key=api_key)
    return llm
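
# Usage sketch: reading the key from the OPENAI_API_KEY environment variable
# is an assumption about how this project stores it:
#
#     llm = get_endpoint(api_key=os.environ.get("OPENAI_API_KEY"))
#     print(llm.invoke("Summarize the following text: ..."))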
def get_model(model_type: str, model_name_or_path: str, api_key=None):
    """Return an OpenAI endpoint or a local summarization pipeline."""
    if model_type == "openai":
        return get_endpoint(api_key)
    return get_local_model(model_name_or_path)
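
# Dispatch sketch for get_model; the checkpoint path and env var below are
# placeholders, not values from this repo:
#
#     pipe = get_model("local", "path/to/led-checkpoint")
#     llm = get_model("openai", None, api_key=os.environ.get("OPENAI_API_KEY"))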