import base64

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from conversation import get_default_conv_template

MODEL_NAME = "GeneZC/MiniChat-3B"
DEFAULT_WORD_LIMIT = 200


@st.cache_resource
def _load_model():
    """Load the tokenizer and model once and reuse them across reruns.

    Streamlit re-executes the script on every interaction; caching the
    loaded objects avoids re-reading the checkpoint on each button click.

    Returns:
        A ``(tokenizer, model, device)`` tuple where ``device`` is the
        string ``"cuda"`` or ``"cpu"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
    # Half precision with automatic placement on GPU; full precision on CPU.
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            use_cache=True,
            device_map="auto",
            torch_dtype=torch.float16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            use_cache=True,
            device_map="cpu",
            torch_dtype=torch.float32,
        )
    return tokenizer, model.eval(), device


def generate(style, topic, words=200, sender='Sender_Name', recipient='Recipient_Name'):
    """Generate an email with the chat model and return it as text.

    Args:
        style: Kind of email requested (inserted verbatim into the prompt).
        topic: Subject matter of the email.
        words: Requested word limit; ``None`` falls back to 200.
        sender: Name to sign the email with.
        recipient: Name of the addressee.

    Returns:
        The decoded model response with surrounding whitespace stripped.
    """
    # The number input in the UI yields None when left empty; without this
    # the prompt would literally ask for "Word Limit: None".
    if words is None:
        words = DEFAULT_WORD_LIMIT

    tokenizer, model, device = _load_model()

    question = (
        "Generate an email with the following specifications and make sure to "
        "highlight the subject, according to the topic of the mail, in the "
        "very first line of the response:\n"
        f"- Style: {style}\n"
        f"- Word Limit: {words}\n"
        f"- Topic: {topic}\n"
        f"- Sender Name: {sender}\n"
        f"- Recipient Name: {recipient}\n"
    )

    # Wrap the request in the conversation template: one user turn followed
    # by an empty assistant turn for the model to complete.
    conv = get_default_conv_template("minichat")
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer([prompt]).input_ids
    output_ids = model.generate(
        torch.as_tensor(input_ids).to(device),
        do_sample=True,
        temperature=0.7,
        max_new_tokens=2000,
    )
    # Skip the first len(prompt) tokens so only the continuation is decoded.
    prompt_length = len(input_ids[0])
    new_tokens = output_ids[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


@st.cache_data
def get_image(file):
    """Return the contents of ``file`` as a base64-encoded ASCII string."""
    with open(file, "rb") as f:
        data = f.read()
    return base64.b64encode(data).decode()


# Inject the project stylesheet into the page. The handle is deliberately not
# named `style`, which is reused below for the email-type input.
with open("./static/style.css", encoding="utf-8") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

img = get_image("./static/bg.png")
# NOTE(review): this template is blank and never references `img`, so no
# background is applied. It presumably should embed the base64 image in a
# <style> rule -- confirm the intended markup and restore it.
bg_image = f""" """
st.markdown(bg_image, unsafe_allow_html=True)

st.title("MailCraft")
st.subheader("Unlock Seamless Email Excellence 📧 for Effortless Communication")

style = st.text_input(
    "Enter Email Type",
    placeholder="Ex. Professional/Personal/Job Application",
)
words = st.number_input(
    "Enter Word Limit",
    min_value=100,
    max_value=500,
    value=None,
    placeholder="From 100 to 500",
)
topic = st.text_input("Enter Email Topic")
sender = st.text_input("Enter Sender Name")
recipient = st.text_input("Enter Recipient Name")

if st.button("Generate Email"):
    if not topic.strip():
        # Generating without a topic only produces filler text.
        st.warning("Please enter an email topic before generating.")
    else:
        generated_email = generate(style, topic, words, sender, recipient)
        st.write(generated_email)