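# Repetivec: a Gradio app that trains a simple bigram language model plus a
# Word2Vec model on an uploaded corpus, generates text from a start word,
# swaps repetitive phrases for Word2Vec-derived alternatives, and cleans up
# the result with regex-based post-processing.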
import gradio as gr
from collections import defaultdict
import random
import re
import nltk
from nltk.tokenize import word_tokenize
from gensim.models import Word2Vec
import numpy as np
import itertools
# Step 1: Data Collection
def import_corpus(file):
    with open(file.name, 'r', encoding='utf-8') as f:
        corpus = f.read()
    return corpus
# Step 2: Data Preprocessing using NLTK
def preprocess_data(corpus):
    words = word_tokenize(corpus)
    words = [word.lower() for word in words]
    data = []
    i = 0
    while i < len(words) - 1:
        # Group quoted dialogue into a single tuple so the quote stays intact.
        # word_tokenize maps straight double quotes to `` and '', so those and
        # curly quotes are both treated as delimiters here.
        if words[i] in ('β€œ', '``', '"'):
            dialogue = [words[i]]
            i += 1
            while i < len(words) - 1 and words[i] not in ('”', "''", '"'):
                dialogue.append(words[i])
                i += 1
            dialogue.append(words[i])
            data.append(tuple(dialogue))
        else:
            data.append((words[i], words[i + 1]))
        i += 1
    return data
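# Example (illustrative): preprocess_data("The cat sat.") lowercases the tokens
# to [the, cat, sat, .] and yields the bigrams ("the", "cat"), ("cat", "sat"),
# ("sat", ".").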
# Step 3: Model Training - Word2Vec
def train_word2vec(corpus):
    tokenized_corpus = [word_tokenize(sentence) for sentence in nltk.sent_tokenize(corpus)]
    model = Word2Vec(sentences=tokenized_corpus, vector_size=100, window=5, min_count=1, workers=4, sg=1)  # sg=1 selects the Skip-gram architecture
    return model
# Step 4: Model Training - Language Model
def train_model(data):
    model = defaultdict(lambda: defaultdict(int))
    # Count bigram transitions. Dialogue entries from preprocess_data may be
    # longer than two tokens, so walk each tuple pairwise instead of unpacking.
    for entry in data:
        for word1, word2 in zip(entry, entry[1:]):
            model[word1][word2] += 1
    # Normalize counts into conditional probabilities P(word2 | word1).
    for word1 in model:
        total_count = float(sum(model[word1].values()))
        for word2 in model[word1]:
            model[word1][word2] /= total_count
    return model
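# Example (illustrative): train_model([("the", "cat"), ("the", "dog"), ("cat", "sat")])
# yields model["the"] == {"cat": 0.5, "dog": 0.5} and model["cat"] == {"sat": 1.0},
# i.e. conditional next-word probabilities estimated from bigram counts.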
def identify_repetitive_phrases(generated_text, min_phrase_length=2, max_phrase_length=5, threshold=0.7):
    # Flag phrases that occur more than once, skipping near-duplicates of
    # phrases already recorded (Jaccard similarity >= threshold).
    words = word_tokenize(generated_text)
    phrases = []
    for phrase_length in range(min_phrase_length, max_phrase_length + 1):
        counts = defaultdict(int)
        for i in range(len(words) - phrase_length + 1):
            counts[' '.join(words[i:i + phrase_length])] += 1
        for phrase, count in counts.items():
            if count < 2:
                continue  # a phrase seen only once is not repetitive
            similarity_scores = [calculate_similarity(phrase, existing_phrase) for existing_phrase in phrases]
            if similarity_scores and max(similarity_scores) >= threshold:
                continue  # skip if similarity with an existing phrase is high
            phrases.append(phrase)
    return phrases
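# Example (illustrative, under the occurs-more-than-once criterion above): in
# "the cat sat . the cat ran", the bigram "the cat" appears twice and is
# flagged, while "cat sat" appears once and is not.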
def calculate_similarity(phrase1, phrase2):
    # Jaccard similarity between the two phrases' token sets.
    tokens1 = phrase1.split()
    tokens2 = phrase2.split()
    intersection = len(set(tokens1) & set(tokens2))
    union = len(set(tokens1) | set(tokens2))
    return intersection / union if union > 0 else 0
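# Example (illustrative): calculate_similarity("the red fox", "the red dog")
# returns 0.5: the token sets share 2 of 4 distinct words (Jaccard index 2/4).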
def replace_repetitive_phrases(generated_text, word2vec_model):
    repetitive_phrases = identify_repetitive_phrases(generated_text)
    replaced_text = generated_text
    for phrase in repetitive_phrases:
        phrase_words = phrase.split()
        replacement = find_alternative_phrase(phrase_words, word2vec_model)
        if replacement:
            replaced_text = replaced_text.replace(phrase, replacement)
    return replaced_text
def find_alternative_phrase(words, word2vec_model):
    alternative_phrases = []
    for word in words:
        if word in word2vec_model.wv:
            # A few nearest neighbours per word keeps the candidate product
            # tractable (topn defaults to 10, which explodes combinatorially).
            similar_words_with_scores = word2vec_model.wv.most_similar(word, topn=3)
            similar_words = [candidate for candidate, _ in similar_words_with_scores]
            alternative_phrases.append(similar_words)
        else:
            # Out-of-vocabulary words stay as-is so the phrase keeps its length.
            alternative_phrases.append([word])
    alternative_phrases_combinations = [' '.join(combination) for combination in itertools.product(*alternative_phrases)]
    highest_similarity = -1
    best_alternative_phrase = None
    for alternative_phrase in alternative_phrases_combinations:
        similarity = calculate_phrase_similarity(words, alternative_phrase.split(), word2vec_model)
        if similarity > highest_similarity:  # cosine similarity is a scalar
            highest_similarity = similarity
            best_alternative_phrase = alternative_phrase
    return best_alternative_phrase
def calculate_phrase_similarity(phrase1, phrase2, word2vec_model):
    # Cosine similarity between the mean word vectors of two token lists.
    vectors1 = [word2vec_model.wv[word] for word in phrase1 if word in word2vec_model.wv]
    vectors2 = [word2vec_model.wv[word] for word in phrase2 if word in word2vec_model.wv]
    if not vectors1 or not vectors2:
        return 0.0  # no in-vocabulary words to compare
    phrase1_vector = np.mean(vectors1, axis=0)
    phrase2_vector = np.mean(vectors2, axis=0)
    norm_product = np.linalg.norm(phrase1_vector) * np.linalg.norm(phrase2_vector)
    if norm_product == 0:
        return 0.0
    return np.dot(phrase1_vector, phrase2_vector) / norm_product
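# The score returned above is the cosine similarity of the averaged vectors:
#     cos(u, v) = (u . v) / (||u|| * ||v||)
# so 1.0 means the phrase vectors point the same way and 0.0 means orthogonal.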
def evaluate_generated_text(generated_text):
    # Implement evaluation logic (like how many phrases were replaced, etc.)
    return ""
def generate_sentence(model, start_word, length=101, context_window_size=4, max_context_window_size=100, blacklist=None, whitelist=None, whitelist_weight=0.1):
    print('======================================================================')
    print('========================== GENERATING SENTENCE ======================')
    print(f'Start word: {start_word}')
    print(f'Length: {length}')
    print(f'Context window size: {context_window_size}')
    print(f'Max context window size: {max_context_window_size}')
    print(f'Blacklist: {blacklist}')
    print(f'Whitelist: {whitelist}')
    print(f'Whitelist weight: {whitelist_weight}')
    print('======================================================================')
    # Initialize blacklist to an empty list if not provided
    if blacklist is None:
        print('Initializing blacklist to empty list')
        blacklist = []
    sentence = [start_word]
    current_word = start_word
    repetitive_phrases = set()
    for i in range(length):
        print(f'Iteration {i+1}')
        print(f'Sentence: {sentence}')
        print(f'Current word: {current_word}')
        print(f'Context window size: {context_window_size}')
        print(f'Blacklist: {blacklist}')
        print(f'Whitelist: {whitelist}')
        # Widen the context window when the current tail has been seen before,
        # making it harder to fall back into the same loop.
        if len(sentence) >= context_window_size and tuple(sentence[-context_window_size:]) in repetitive_phrases:
            print(f'Increasing context window size to: {context_window_size + 1}')
            context_window_size = min(context_window_size + 1, max_context_window_size)
        print(f'Next word candidates: {model[current_word].keys()}')
        next_word_candidates = [word for word in model[current_word].keys() if word not in blacklist]
        if whitelist:
            priority_words = [word for word in next_word_candidates if word in whitelist]
            if priority_words:
                print(f'Whitelist priority words: {priority_words}')
                # With probability whitelist_weight, restrict the pool to
                # whitelisted words; otherwise exclude them for this step.
                if random.random() < whitelist_weight:
                    next_word_candidates = priority_words
                else:
                    next_word_candidates = [word for word in next_word_candidates if word not in whitelist]
        if not next_word_candidates:
            break
        next_word = random.choice(next_word_candidates)
        sentence.append(next_word)
        current_word = next_word
        if len(sentence) >= context_window_size:
            repetitive_phrases.add(tuple(sentence[-context_window_size:]))
    generated_sentence = ' '.join(sentence)
    print(f'Generated sentence: {generated_sentence}')
    return generated_sentence
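# Usage sketch (illustrative; assumes `model` was built by train_model above):
#     generate_sentence(model, "the", length=20, whitelist=["cat"], whitelist_weight=0.5)
# With probability whitelist_weight a step draws only from whitelisted
# candidates; otherwise whitelisted words are excluded for that step.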
def post_process_generated_text(generated_text):
    # Perform post-processing steps to improve readability, coherence, and grammar
    # Correct spacing around punctuation marks
    generated_text = re.sub(r'\s([?,.!"](?:\s|$))', r'\1', generated_text)
    # Capitalize the first letter of each sentence
    generated_text = '. '.join(sentence.capitalize() for sentence in generated_text.split('. '))
    # Correct repeated punctuation
    generated_text = re.sub(r'([?.!"])\1+', r'\1', generated_text)
    # Remove space before right double quotation mark (”)
    generated_text = re.sub(r'\s([”])', r'\1', generated_text)
    # Remove space after left double quotation mark (β€œ)
    generated_text = re.sub(r'([β€œ])\s', r'\1', generated_text)
    return generated_text
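# Example (illustrative):
#     post_process_generated_text('hello , world . how are you ?')
#     -> 'Hello, world. How are you?'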
def generate_with_gradio(start_word, file, length=101, context_window_size=4, max_context_window_size=100, blacklist=None, whitelist=None, whitelist_weight=0.1):
    # Load the corpus from the uploaded file
    corpus = import_corpus(file)
    # Preprocess the data
    data = preprocess_data(corpus)
    # Train the language model
    language_model = train_model(data)
    # Train the Word2Vec model
    word2vec_model = train_word2vec(corpus)
    # Gradio passes numbers as floats and the list fields as comma-separated
    # strings, so normalize them before generation.
    length = int(length)
    context_window_size = int(context_window_size)
    max_context_window_size = int(max_context_window_size)
    blacklist = [word.strip() for word in blacklist.split(',') if word.strip()] if blacklist else None
    whitelist = [word.strip() for word in whitelist.split(',') if word.strip()] if whitelist else None
    # Generate the sentence
    generated_sentence = generate_sentence(language_model, start_word, length, context_window_size, max_context_window_size, blacklist=blacklist, whitelist=whitelist, whitelist_weight=whitelist_weight)
    # Replace repetitive phrases
    replaced_sentence = replace_repetitive_phrases(generated_sentence, word2vec_model)
    # Post-process the generated sentence
    processed_sentence = post_process_generated_text(replaced_sentence)
    return processed_sentence
nltk.download('punkt')
nltk.download('punkt_tab')  # required by newer NLTK releases for word_tokenize
# Create a Gradio interface with file uploader
iface = gr.Interface(
    fn=generate_with_gradio,
    inputs=[
        gr.Textbox(label="Start Word"),
        gr.File(label="Upload Corpus"),
        gr.Number(label="Length", value=101),
        gr.Number(label="Context Window Size", value=4),
        gr.Number(label="Max Context Window Size", value=100),
        gr.Textbox(label="Blacklist (comma-separated)"),
        gr.Textbox(label="Whitelist (comma-separated)"),
        gr.Number(label="Whitelist Weight", value=0.1)
    ],
    outputs="text",
    title="Sentence Generator with Repetivec",
    description="Enter a starting word and upload a corpus file to generate a sentence."
)
iface.launch()