storyGeneration / app.py
Thomslionel's picture
Update app.py
02f309f verified
import os
os.system('sh setup.sh')
import streamlit as st
import torch
from unsloth import FastLanguageModel
from transformers import TextStreamer
# Configuration du modèle
max_seq_length = 2048
dtype = None
load_in_4bit = True
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="Thomslionel/mistral_for_story_generation",
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
)
model = FastLanguageModel.for_inference(model)
EOS_TOKEN = tokenizer.eos_token
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Réponse:"
END_KEY = EOS_TOKEN
INTRO_BLURB = "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète adéquatement la demande."
PROMPT_FOR_GENERATION_FORMAT = """{intro}
{instruction_key}
{instruction}
{response_key}
{end_key}
""".format(
intro=INTRO_BLURB,
instruction_key=INSTRUCTION_KEY,
instruction="{instruction}",
response_key=RESPONSE_KEY,
end_key=END_KEY,
)
# Fonction pour créer le prompt avec le titre
def create_prompt_with_title(titre):
instruction = (
f"Tu es un écrivain talentueux spécialisé dans les histoires culturelles Burkinabè comiques et narratives. "
f"Ton objectif est de créer une histoire captivante, drôle, culturelle et cohérente. "
f"Le titre de l'histoire est : {titre}. "
f"Assure-toi d'inclure des éléments humoristiques, des personnages intéressants et une intrigue bien développée. "
f"Voici le titre : {titre}. Bonne écriture !"
)
return f"### Instruction:\n{instruction}\n### Response:\n"
# Interface Streamlit
st.title("Générateur d'Histoires Burkinabè")
titre = st.text_input("Entrez le titre de l'histoire", "M'BA SOAMBA (le lièvre) et M'BA BAAGA (le chien)")
if st.button("Générer l'histoire"):
prompt = create_prompt_with_title(titre)
inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
text_streamer = TextStreamer(tokenizer)
with st.spinner("Génération de l'histoire..."):
generated_text = model.generate(
**inputs,
streamer=text_streamer,
max_new_tokens=2048,
temperature=1.9,
top_p=0.9,
top_k=50,
repetition_penalty=1.3
)
st.success("Histoire générée avec succès!")
st.text_area("Histoire générée", value=generated_text, height=400)