SOP_improved / app.py
Jaane's picture
Update app.py
d26463a verified
raw
history blame
5.07 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
from sentence_transformers import SentenceTransformer, util
import requests
import os
import warnings
from transformers import logging
# --- Silence noisy third-party output ---------------------------------------
# Category-specific filters first, then a blanket "ignore". filterwarnings()
# prepends to the filter list, so the blanket rule ends up first and wins.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore")
# `logging` here is transformers.logging (see imports): errors only.
logging.set_verbosity_error()

# --- Configuration pulled from the environment ------------------------------
# Groq key for sentence segmentation; configure it as a secret on the Space.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
# TensorFlow C++ log level 2 hides INFO and WARNING messages.
# NOTE(review): this runs after the imports above; if TensorFlow was already
# loaded by them, the setting may have no effect — consider moving it earlier.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Groq API sentence segmentation
_GROQ_CHAT_URL = "https://api.groq.com/openai/v1/chat/completions"
# Sentinel the model is asked to append after each sentence.
_SENTENCE_MARKER = "1!2@3#"


def segment_into_sentences_groq(passage, *, api_key=None, model="llama3-8b-8192", timeout=60):
    """Split *passage* into sentences using a Groq-hosted chat model.

    The model is instructed to append ``1!2@3#`` after every sentence; the
    reply is split on that marker.

    Args:
        passage: Text to segment.
        api_key: Groq API key. ``None`` (default) uses the module-level
            ``GROQ_API_KEY`` read from the environment.
        model: Groq chat model identifier.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A list of non-empty, whitespace-stripped sentences. Empty or
        whitespace-only input returns ``[]`` without calling the API.

    Raises:
        ValueError: If no API key is available or the API returns a
            non-200 status.
        requests.RequestException: On network errors or timeout.
    """
    if not passage or not passage.strip():
        return []

    key = GROQ_API_KEY if api_key is None else api_key
    if not key:
        # Fail early with a clear message instead of sending "Bearer None".
        raise ValueError("GROQ_API_KEY is not set; cannot segment text.")

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": [
            {
                "role": "system",
                "content": "Segment sentences by adding '1!2@3#' at the end of each sentence.",
            },
            {
                "role": "user",
                "content": f"Segment the passage: {passage}",
            },
        ],
        # The passage should come back verbatim apart from the markers;
        # high-temperature sampling invites rewording or dropped text.
        "temperature": 0.0,
        "max_tokens": 8192,
    }
    # An explicit timeout keeps a stalled connection from hanging the request forever.
    response = requests.post(_GROQ_CHAT_URL, json=payload, headers=headers, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"Groq API error ({response.status_code}): {response.text}")

    data = response.json()
    # Tolerate a missing/empty "choices" list or null fields instead of crashing.
    choices = data.get("choices") or [{}]
    segmented_text = (choices[0].get("message") or {}).get("content") or ""
    return [part.strip() for part in segmented_text.split(_SENTENCE_MARKER) if part.strip()]
# Text enhancement class
class TextEnhancer:
    """Paraphrase text sentence by sentence while preserving its meaning.

    For each sentence returned by ``segment_into_sentences_groq``:

    1. generate candidate paraphrases with the Parrot T5 model,
    2. score each candidate against the original sentence by cosine
       similarity of SentenceTransformer embeddings,
    3. take the highest-scoring candidate whose score is at least
       ``min_similarity`` and pass it through the CoEdIT grammar model.

    Sentences with no acceptable paraphrase are kept unchanged.
    """

    # Task prefix Parrot is prompted with.
    PARAPHRASE_PREFIX = "paraphrase: "
    # CoEdIT is instruction-tuned and expects the edit task as a prefix;
    # without one, the kind of edit it performs is unspecified.
    GRAMMAR_INSTRUCTION = "Fix grammatical errors in this sentence: "
    # Token limit used for tokenization and for both generation steps.
    MAX_LENGTH = 150
    # A sentence counts as terminated if, after stripping trailing closing
    # quotes/brackets, it ends with one of these characters.
    _SENTENCE_END = ".!?…"
    _CLOSERS = "\"'”’)]"

    def __init__(self):
        """Load the paraphrase, grammar, and similarity models onto GPU if available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.paraphrase_tokenizer = AutoTokenizer.from_pretrained("prithivida/parrot_paraphraser_on_T5")
        self.paraphrase_model = T5ForConditionalGeneration.from_pretrained(
            "prithivida/parrot_paraphraser_on_T5"
        ).to(self.device)
        self.grammar_pipeline = pipeline(
            "text2text-generation",
            model="Grammarly/coedit-large",
            device=0 if self.device == "cuda" else -1,
        )
        self.similarity_model = SentenceTransformer("paraphrase-MiniLM-L6-v2").to(self.device)

    def enhance_text(self, text, min_similarity=0.8, max_variations=2):
        """Return an enhanced version of *text*.

        Args:
            text: Input passage.
            min_similarity: Minimum cosine similarity (0-1) a paraphrase must
                reach against its source sentence to be used.
            max_variations: Number of beam-search paraphrases per sentence.

        Returns:
            The enhanced sentences joined by single spaces, each ending in
            terminal punctuation; ``""`` for empty or whitespace-only input.
        """
        if not text or not text.strip():
            return ""
        enhanced = [
            self._enhance_sentence(sentence, min_similarity, max_variations)
            for sentence in segment_into_sentences_groq(text)
            if sentence.strip()
        ]
        return self._join_sentences(enhanced)

    def _enhance_sentence(self, sentence, min_similarity, max_variations):
        """Return the grammar-corrected best paraphrase of *sentence*, or *sentence* itself."""
        paraphrases = self._paraphrase(sentence, max_variations)
        scores = self._similarities(sentence, paraphrases)
        best = self._select_best(paraphrases, scores, min_similarity)
        return sentence if best is None else self._correct_grammar(best)

    def _paraphrase(self, sentence, max_variations):
        """Generate *max_variations* beam-search paraphrases of *sentence*."""
        inputs = self.paraphrase_tokenizer(
            f"{self.PARAPHRASE_PREFIX}{sentence}",
            return_tensors="pt",
            padding=True,
            max_length=self.MAX_LENGTH,
            truncation=True,
        ).to(self.device)
        outputs = self.paraphrase_model.generate(
            **inputs,
            max_length=self.MAX_LENGTH,
            num_return_sequences=max_variations,
            # num_return_sequences may not exceed num_beams.
            num_beams=max_variations,
        )
        return self.paraphrase_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def _similarities(self, sentence, paraphrases):
        """Return the cosine similarity of *sentence* to each paraphrase, as floats."""
        if not paraphrases:
            return []
        sentence_embedding = self.similarity_model.encode(sentence)
        paraphrase_embeddings = self.similarity_model.encode(paraphrases)
        # Row 0 holds the scores for the single query sentence.
        row = util.cos_sim(sentence_embedding, paraphrase_embeddings)[0]
        return [float(score) for score in row]

    @staticmethod
    def _select_best(paraphrases, scores, min_similarity):
        """Return the highest-scoring paraphrase with score >= *min_similarity*, else None.

        On ties the earliest (highest-ranked beam) candidate wins.
        """
        acceptable = [
            (para, score)
            for para, score in zip(paraphrases, scores)
            if score >= min_similarity
        ]
        if not acceptable:
            return None
        return max(acceptable, key=lambda item: item[1])[0]

    def _correct_grammar(self, text):
        """Run *text* through CoEdIT with an explicit grammar-fix instruction."""
        result = self.grammar_pipeline(
            f"{self.GRAMMAR_INSTRUCTION}{text}",
            max_length=self.MAX_LENGTH,
            num_return_sequences=1,
        )
        return result[0]["generated_text"].strip()

    @classmethod
    def _join_sentences(cls, sentences):
        """Join sentences with single spaces, adding "." only where one is missing.

        Blank entries are dropped; returns ``""`` when nothing remains.
        (Joining with ". " would double the punctuation already present.)
        """
        finished = []
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            core = sentence.rstrip(cls._CLOSERS)
            if not core or core[-1] not in cls._SENTENCE_END:
                sentence += "."
            finished.append(sentence)
        return " ".join(finished)
# Gradio interface
def create_interface():
    """Build the Gradio UI around a single, eagerly loaded ``TextEnhancer``.

    Returns:
        A ``gr.Interface`` that takes a text box and a similarity slider
        (50-100 %, default 80) and shows the enhanced text.
    """
    # Local import: the module-level name `logging` is transformers.logging.
    from logging import getLogger

    logger = getLogger(__name__)
    # Models are loaded once here and reused for every request.
    enhancer = TextEnhancer()

    def process_text(text, similarity_threshold):
        """Enhance *text*; the slider gives a percentage, the enhancer wants 0-1."""
        if not text or not text.strip():
            return ""
        try:
            return enhancer.enhance_text(text, min_similarity=similarity_threshold / 100)
        except Exception as e:  # UI boundary: report the failure instead of crashing
            # Keep the full traceback in the server log; the user sees only the message.
            logger.exception("Text enhancement failed")
            return f"Error: {str(e)}"

    return gr.Interface(
        fn=process_text,
        inputs=[
            gr.Textbox(lines=10, placeholder="Enter text to enhance...", label="Input Text"),
            gr.Slider(minimum=50, maximum=100, value=80, label="Minimum Semantic Similarity (%)"),
        ],
        outputs=gr.Textbox(lines=10, label="Enhanced Text"),
        title="Text Enhancement System",
        description="Enhance text quality with semantic preservation.",
    )
if __name__ == "__main__":
    # Bind on all interfaces at port 7860 so the app is reachable from outside
    # the container.
    create_interface().launch(server_name="0.0.0.0", server_port=7860)