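"""Streamlit demo: load a LoRA adapter with PEFT on top of its base causal LM
and rewrite user-supplied news headlines in a satirical, Onion-style voice."""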
import streamlit as st
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub ID of the LoRA adapter fine-tuned for satirical headlines
peft_model_id = "fiona/to_onion_news_converter"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base causal LM and tokenizer referenced by the adapter config
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=False)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Apply the LoRA adapter weights on top of the base model
model = PeftModel.from_pretrained(model, peft_model_id)

def make_inference(news_headline):
    # Build the same instruction-style prompt the adapter was fine-tuned on
    batch = tokenizer(f"### INSTRUCTION\nBelow is a standard news headline, please rewrite it in a satirical style .\n\n### Standard:\n{news_headline}\n\n### new news:\n", return_tensors='pt')

    # Generate the rewrite; autocast is a no-op unless the model has been moved to a CUDA device
    with torch.cuda.amp.autocast():
        output_tokens = model.generate(**batch, max_new_tokens=200)

    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
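
# Example of calling the helper directly (hypothetical headline; output depends on the model):
#   print(make_inference("Local council approves new bike lanes"))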

if __name__ == "__main__":
    # Title of the web application
    st.title('Onion News Converter')

    # Text input widget for the headline to rewrite
    user_input = st.text_input('Enter a news headline', '')

    # Display the satirical rewrite directly below the input field
    if user_input:
        st.write('The Onion style:', make_inference(user_input))
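
# To try the app locally (assuming this script is saved as app.py):
#   streamlit run app.py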