|
import re |
|
import sqlite3 |
|
from flask import g |
|
from transformers import pipeline, set_seed |
|
|
|
# Text-generation pipelines built so far, keyed by model name. Constructing a
# pipeline loads the model weights, so it is done once per process instead of
# once per request.
_TEXT_GENERATORS = {}


def _get_text_generator(model_name='gpt2'):
    """Return a cached Hugging Face text-generation pipeline for *model_name*."""
    generator = _TEXT_GENERATORS.get(model_name)
    if generator is None:
        generator = pipeline('text-generation', model=model_name)
        _TEXT_GENERATORS[model_name] = generator
    return generator


def generate(Entered_story):
    """Generate text from a user story with GPT-2.

    Args:
        Entered_story: The user-story text entered by the user.

    Returns:
        The ``'generated_text'`` field of the first sequence returned by the
        text-generation pipeline.

    Raises:
        ValueError: ``"Empty input!"`` if the story is empty or whitespace only,
            or ``"Incorrect format!"`` if ``validate_story`` rejects it.
    """
    if not Entered_story.strip():
        raise ValueError("Empty input!")

    if not validate_story(Entered_story):
        raise ValueError("Incorrect format!")

    generator = _get_text_generator()

    # Re-seed right before every generation so a given story produces the same
    # output on each call (the pipeline is cached, the seed is not).
    set_seed(42)

    # NOTE(review): five sequences are sampled but only the first is returned,
    # and in transformers `max_length` counts the prompt tokens too (long
    # stories get little or no continuation). Both are kept as-is because
    # changing them would change the generated text; consider
    # num_return_sequences=1 / max_new_tokens after confirming.
    results = generator(Entered_story, max_length=30, num_return_sequences=5)

    return results[0]['generated_text']
|
|
|
|
|
# User-story shape accepted by validate_story():
#   "As a <role>, I want to <goal>...so that"
# - <role> is one or more characters other than ',' or '.'
# - the goal must start with a character other than ',' or '.', and at least
#   one more character (of any kind, newlines included via DOTALL) must follow
#   before the literal "so that".
# This accepts exactly the same strings as the earlier pattern
# r'As a (?P<role>[^,.]+), I want to (?P<goal>[^,.]+)(?:,|.)+\s*so that',
# whose unescaped '.' made '(?:,|.)+' match "any characters" while giving
# every comma two ways to match -- exponential backtracking on inputs with
# many commas and no "so that".
_USER_STORY_PATTERN = re.compile(
    r'As a [^,.]+, I want to [^,.].+so that',
    flags=re.DOTALL,
)


def validate_story(Entered_story):
    """Return True if *Entered_story* contains a user story in the expected shape.

    The check is case-sensitive and may match anywhere in the text
    (``re.search``); see ``_USER_STORY_PATTERN`` for the accepted shape.
    """
    return _USER_STORY_PATTERN.search(Entered_story) is not None
|
|
|
|
|
|
|
def getTextGenContents():
    """Fetch every stored text-generation record.

    The sqlite3 connection is kept on ``flask.g`` as ``_database`` so later
    calls that see the same ``g`` reuse it; if none is stored yet, a new
    connection to 'Refineverse.db' is opened and saved there.

    Returns:
        The ``fetchall()`` result of selecting ``userStory, generatedStory``
        from the TextGeneration table (tuples under sqlite3's default row
        factory).
    """
    # NOTE(review): this function never closes the cached connection;
    # presumably an app-context teardown handler elsewhere does -- confirm.
    connection = getattr(g, '_database', None)
    if connection is None:
        connection = sqlite3.connect('Refineverse.db')
        g._database = connection

    query = "SELECT userStory, generatedStory FROM TextGeneration"
    return connection.cursor().execute(query).fetchall()
|
|
|
|
|
|
|
def insertTextGenRow(Entered_story, generatedStory):
    """Store one (user story, generated story) pair in the TextGeneration table.

    Opens a dedicated connection to 'Refineverse.db' for this insert and
    always closes it afterwards.

    Args:
        Entered_story: The original user-story text.
        generatedStory: The text generated for that story.

    Raises:
        sqlite3.Error: If connecting or inserting fails. A failed insert is
            rolled back before the error propagates.
    """
    conn = sqlite3.connect('Refineverse.db')
    try:
        # As a context manager a sqlite3 connection commits on success and
        # rolls back on an exception, but it does NOT close itself -- hence
        # the explicit close() in the finally clause.
        with conn:
            conn.execute(
                "INSERT INTO TextGeneration (userStory, generatedStory) VALUES (?, ?)",
                (Entered_story, generatedStory),
            )
    finally:
        conn.close()
|
|