|
import gradio as gr |
|
from collections import defaultdict |
|
import random |
|
import re |
|
import nltk |
|
from nltk.tokenize import word_tokenize |
|
from gensim.models import Word2Vec |
|
import numpy as np |
|
import itertools |
|
|
|
|
|
def import_corpus(file):
    """Read and return the full text of an uploaded corpus file.

    Parameters
    ----------
    file : object with a ``.name`` attribute (e.g. Gradio's file wrapper)
        ``file.name`` is the path of the uploaded file on disk.

    Returns
    -------
    str
        The file's entire contents, decoded as UTF-8.
    """
    # Use a distinct handle name: the original re-bound the parameter
    # (`as file`), shadowing it inside the `with` block.
    with open(file.name, 'r', encoding='utf-8') as handle:
        return handle.read()
|
|
|
|
|
def preprocess_data(corpus):
    """Tokenize the corpus and build training data for the bigram model.

    Returns a list of tuples: ordinary positions contribute
    ``(word, next_word)`` bigram pairs, while a span opened by a ``'"'``
    token is collected whole as a single tuple of all its tokens up to
    (and including) the token where the scan stops.

    NOTE(review): the opening marker tested is ``'"'`` but the closing
    marker is ``'β'`` — presumably the corpus uses ``'β'`` as a quote
    delimiter (post_process_generated_text also treats ``'β'`` specially);
    confirm against the corpus format. Also note dialogue tuples may be
    longer than 2, so consumers must handle both shapes.
    """
    words = word_tokenize(corpus)
    # Lowercase everything so the model is case-insensitive.
    words = [word.lower() for word in words]
    data = []
    i = 0
    # Stop one short of the end: the bigram branch peeks at words[i + 1].
    while i < len(words) - 1:
        if words[i] == '"':
            # Start of a quoted/dialogue span: gather tokens verbatim.
            dialogue = [words[i]]
            i += 1
            # Collect until the closing 'β' marker or the second-to-last token.
            while i < len(words) - 1 and words[i] != 'β':
                dialogue.append(words[i])
                i += 1
            # Append the terminator (or whatever token the scan stopped on;
            # safe because the inner guard keeps i < len(words)).
            dialogue.append(words[i])
            data.append(tuple(dialogue))
        else:
            # Ordinary adjacent-word bigram.
            data.append((words[i], words[i + 1]))
        i += 1
    return data
|
|
|
|
|
def train_word2vec(corpus):
    """Train a skip-gram Word2Vec model on the corpus, one sentence at a time."""
    sentences = [word_tokenize(sent) for sent in nltk.sent_tokenize(corpus)]
    return Word2Vec(
        sentences=sentences,
        vector_size=100,
        window=5,
        min_count=1,
        workers=4,
        sg=1,  # skip-gram
    )
|
|
|
|
|
def train_model(data):
    """Build a first-order Markov model P(word2 | word1) from the data.

    Parameters
    ----------
    data : list[tuple[str, ...]]
        Tuples from ``preprocess_data``: mostly bigram pairs, but dialogue
        spans may be arbitrarily long tuples.

    Returns
    -------
    defaultdict
        ``model[word1][word2]`` is the conditional transition probability.
        (Kept as a defaultdict: callers index unseen words and rely on
        getting an empty mapping back rather than a KeyError.)
    """
    model = defaultdict(lambda: defaultdict(int))
    for entry in data:
        # Bug fix: the original unpacked `for word1, word2 in data`, which
        # raises ValueError on dialogue tuples longer than 2. Counting all
        # consecutive pairs is identical for 2-tuples and handles the rest.
        for word1, word2 in zip(entry, entry[1:]):
            model[word1][word2] += 1

    # Normalize each word's follower counts into probabilities.
    for word1 in model:
        total_count = float(sum(model[word1].values()))
        for word2 in model[word1]:
            model[word1][word2] /= total_count
    return model
|
|
|
def identify_repetitive_phrases(generated_text, min_phrase_length=2, max_phrase_length=5, threshold=0.7):
    """Return phrases that recur (exactly or near-exactly) in the text.

    Scans every n-gram of length ``min_phrase_length``..``max_phrase_length``.
    An n-gram is reported as repetitive when it has already been seen
    verbatim earlier in the scan, or when its Jaccard similarity to a
    previously-seen n-gram reaches ``threshold``.

    The original body was an unfinished stub (all branches were ``pass``)
    and always returned an empty list; this completes the evident intent.
    """
    words = word_tokenize(generated_text)
    seen = []     # every n-gram encountered so far
    phrases = []  # n-grams judged repetitive (the result)

    for phrase_length in range(min_phrase_length, max_phrase_length + 1):
        for i in range(len(words) - phrase_length + 1):
            phrase = ' '.join(words[i:i + phrase_length])

            if phrase in seen:
                # Exact repeat of an earlier n-gram.
                if phrase not in phrases:
                    phrases.append(phrase)
                continue

            similarity_scores = [calculate_similarity(phrase, existing) for existing in seen]
            if similarity_scores and max(similarity_scores) >= threshold:
                # Near-duplicate of something already produced.
                phrases.append(phrase)
            seen.append(phrase)

    return phrases
|
|
|
def calculate_similarity(phrase1, phrase2):
    """Jaccard similarity between the word sets of two space-separated phrases."""
    words1 = set(phrase1.split())
    words2 = set(phrase2.split())
    union = words1 | words2
    if not union:
        # Both phrases empty: avoid dividing by zero.
        return 0
    return len(words1 & words2) / len(union)
|
|
|
def replace_repetitive_phrases(generated_text, word2vec_model):
    """Swap each repetitive phrase in the text for a Word2Vec-derived alternative.

    Phrases with no usable alternative are left untouched.
    """
    result = generated_text
    for phrase in identify_repetitive_phrases(generated_text):
        alternative = find_alternative_phrase(phrase.split(), word2vec_model)
        if alternative:
            result = result.replace(phrase, alternative)
    return result
|
|
|
def find_alternative_phrase(words, word2vec_model):
    """Build the candidate replacement phrase most similar to ``words``.

    For each word, gather its Word2Vec nearest neighbours as options;
    out-of-vocabulary words are kept as their own single option (the
    original silently dropped them, which shortened the replacement
    phrase). Every combination of options is scored against the original
    phrase and the highest-scoring one is returned, or None when there is
    nothing to build from.
    """
    per_word_options = []
    for word in words:
        if word in word2vec_model.wv:
            neighbours = [similar for similar, _score in word2vec_model.wv.most_similar(word)]
            per_word_options.append(neighbours)
        else:
            # Keep OOV words unchanged so the phrase keeps its length.
            per_word_options.append([word])

    if not per_word_options:
        # Guard: itertools.product() of no iterables yields one empty tuple,
        # which previously produced an empty-string "best" phrase.
        return None

    best_phrase = None
    highest_similarity = -1.0
    for combination in itertools.product(*per_word_options):
        candidate = ' '.join(combination)
        similarity = calculate_phrase_similarity(words, candidate, word2vec_model)
        # similarity is a plain scalar; the original wrapped it in np.any().
        if similarity > highest_similarity:
            highest_similarity = similarity
            best_phrase = candidate

    return best_phrase
|
|
|
def calculate_phrase_similarity(phrase1, phrase2, word2vec_model):
    """Cosine similarity between the mean word vectors of two phrases.

    Accepts each phrase either as a sequence of words or as one
    space-separated string — callers pass both forms. Returns 0.0 when
    either phrase has no in-vocabulary words or a zero mean vector.
    """
    def _tokens(phrase):
        # Bug fix: the original did ' '.join(phrase) unconditionally, which
        # space-separates the *characters* of a string argument
        # ("abc" -> "a b c"). Split strings; pass sequences through.
        return phrase.split() if isinstance(phrase, str) else list(phrase)

    def _mean_vector(tokens):
        vectors = [word2vec_model.wv[w] for w in tokens if w in word2vec_model.wv]
        # Guard: np.mean([]) warns and yields nan.
        return np.mean(vectors, axis=0) if vectors else None

    vec1 = _mean_vector(_tokens(phrase1))
    vec2 = _mean_vector(_tokens(phrase2))

    if vec1 is None or vec2 is None or not (np.any(vec1) and np.any(vec2)):
        return 0.0
    return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
|
|
|
def evaluate_generated_text(generated_text):
    """Hook for scoring generated text; currently a stub with no metrics."""
    evaluation_report = ""
    return evaluation_report
|
|
|
def generate_sentence(model, start_word, length=101, context_window_size=4, max_context_window_size=100, blacklist=None, whitelist=None, whitelist_weight=0.1):
    """Random-walk the bigram ``model`` to append up to ``length`` words.

    Parameters
    ----------
    model : mapping
        ``model[word]`` maps follower words to probabilities (only the
        keys are used; the next word is chosen uniformly among them).
    start_word : str
        First word of the sentence and starting state of the walk.
    length : int
        Maximum number of words to append after ``start_word``.
    context_window_size / max_context_window_size : int
        Size of the n-gram tracked for repetition; the window grows (up to
        the max) whenever the latest n-gram has been produced before.
    blacklist : list[str] | None
        Words excluded from the candidate pool.
    whitelist : list[str] | None
        Words favoured with probability ``whitelist_weight``; otherwise
        they are excluded from that step's pool.

    Returns
    -------
    str
        The space-joined sentence. Stops early when no candidate remains.
    """
    print('======================================================================')
    print('========================== GENERATING SENTENCE ======================')
    print(f'Start word: {start_word}')
    print(f'Length: {length}')
    print(f'Context window size: {context_window_size}')
    print(f'Max context window size: {max_context_window_size}')
    print(f'Blacklist: {blacklist}')
    print(f'Whitelist: {whitelist}')
    print(f'Whitelist weight: {whitelist_weight}')
    print('======================================================================')

    if blacklist is None:
        print('Initializing blacklist to empty list')
        blacklist = []

    sentence = [start_word]
    current_word = start_word
    repetitive_phrases = set()

    for i in range(length):
        print(f'Iteration {i+1}')
        print(f'Sentence: {sentence}')
        print(f'Current word: {current_word}')
        print(f'Context window size: {context_window_size}')
        print(f'Blacklist: {blacklist}')
        print(f'Whitelist: {whitelist}')

        # Widen the context window when the most recent n-gram has already
        # been produced, to steer the walk away from loops.
        if len(sentence) >= context_window_size and tuple(sentence[-context_window_size:]) in repetitive_phrases:
            print(f'Increasing context window size to: {context_window_size + 1}')
            context_window_size = min(context_window_size + 1, max_context_window_size)

        print(f'Next word candidates: {model[current_word].keys()}')
        next_word_candidates = [word for word in model[current_word].keys() if word not in blacklist]

        if whitelist:
            priority_words = [word for word in next_word_candidates if word in whitelist]
            if priority_words:
                print(f'Whitelist priority words: {priority_words}')
                if random.random() < whitelist_weight:
                    next_word_candidates = priority_words
                else:
                    next_word_candidates = [word for word in next_word_candidates if word not in whitelist]

        if not next_word_candidates:
            break

        next_word = random.choice(next_word_candidates)
        # NOTE: candidates are already filtered against the blacklist, so the
        # original "remove next_word from blacklist" branch was unreachable
        # and has been dropped. The original startswith('β')/endswith('β')
        # if/else appended identically in both arms, so one append suffices.
        sentence.append(next_word)
        current_word = next_word

        if len(sentence) >= context_window_size:
            repetitive_phrases.add(tuple(sentence[-context_window_size:]))

    generated_sentence = ' '.join(sentence)
    print(f'Generated sentence: {generated_sentence}')
    return generated_sentence
|
|
|
def post_process_generated_text(generated_text):
    """Clean up raw generated token text for display.

    Steps, in order: pull whitespace off clause-ending punctuation,
    capitalize each sentence, collapse runs of repeated punctuation, and
    glue 'β' quote markers directly onto their neighbouring words.
    """
    text = generated_text

    # Drop the space before ? , . ! " when the mark is followed by a space
    # or the end of the string.
    text = re.sub(r'\s([?,.!"](?:\s|$))', r'\1', text)

    # Capitalize the first letter of every sentence.
    text = '. '.join(part.capitalize() for part in text.split('. '))

    # Collapse repeated ? . ! " into a single mark.
    text = re.sub(r'([?.!"])\1+', r'\1', text)

    # Attach 'β' markers to the word before and after them.
    text = re.sub(r'\s([β])', r'\1', text)
    text = re.sub(r'([β])\s', r'\1', text)

    return text
|
|
|
def generate_with_gradio(start_word, file, length=101, context_window_size=4, max_context_window_size=100, blacklist=None, whitelist=None, whitelist_weight=0.1):
    """Gradio entry point: train on the uploaded corpus, then generate text.

    Parameters mirror the UI inputs. ``blacklist``/``whitelist`` arrive
    from gr.Textbox as comma-separated strings and are parsed into word
    lists here (the original passed the raw string through, so the
    downstream ``word in blacklist`` tests became substring checks).
    Numeric inputs arrive from gr.Number as floats and are coerced to int
    so ``range(length)`` works.
    """
    def _parse_word_list(value):
        # Accept either an already-split list or the UI's comma-separated string.
        if isinstance(value, str):
            words = [w.strip() for w in value.split(',') if w.strip()]
            return words or None
        return value

    blacklist = _parse_word_list(blacklist)
    whitelist = _parse_word_list(whitelist)

    corpus = import_corpus(file)
    data = preprocess_data(corpus)
    language_model = train_model(data)
    word2vec_model = train_word2vec(corpus)

    generated_sentence = generate_sentence(
        language_model,
        start_word,
        int(length),
        int(context_window_size),
        int(max_context_window_size),
        blacklist=blacklist,
        whitelist=whitelist,
        whitelist_weight=whitelist_weight,
    )

    replaced_sentence = replace_repetitive_phrases(generated_sentence, word2vec_model)
    return post_process_generated_text(replaced_sentence)
|
|
|
# One-time download of the NLTK tokenizer models used by
# word_tokenize / sent_tokenize above.
nltk.download('punkt')

# Wire the generation pipeline into a simple Gradio form.
iface = gr.Interface(
    fn=generate_with_gradio,
    inputs=[
        "text",
        gr.File(label="Upload Corpus"),
        gr.Number(label="Length", value=101),
        gr.Number(label="Context Window Size", value=4),
        gr.Number(label="Max Context Window Size", value=100),
        gr.Textbox(label="Blacklist (comma-separated)"),
        gr.Textbox(label="Whitelist (comma-separated)"),
        gr.Number(label="Whitelist Weight", value=0.1)
    ],
    outputs="text",
    # Fixed the garbled original title ("... with Repetivecc").
    title="Sentence Generator with Repetition Control",
    description="Enter a starting word and upload a corpus file to generate a sentence."
)
iface.launch()
|
|