Jaane commited on
Commit
7fb2b0e
1 Parent(s): 2e4f657

created app

Browse files
Files changed (1) hide show
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import requests
6
+ import random
7
+ import warnings
8
+ from transformers import logging
9
+ import os
10
+ import tensorflow as tf
11
+
12
+ # Set environment configurations
13
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
14
+ tf.get_logger().setLevel('ERROR')
15
+ warnings.filterwarnings("ignore")
16
+ logging.set_verbosity_error()
17
+
18
+ GROQ_API_KEY = "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"
19
+
20
+ def segment_into_sentences_groq(passage):
21
+ headers = {
22
+ "Authorization": f"Bearer {GROQ_API_KEY}",
23
+ "Content-Type": "application/json"
24
+ }
25
+ payload = {
26
+ "model": "llama3-8b-8192",
27
+ "messages": [
28
+ {
29
+ "role": "system",
30
+ "content": "you are to segment the sentence by adding '1!2@3#' at the end of each sentence. Return only the segmented sentences only return the modified passage and nothing else do not add your responses"
31
+ },
32
+ {
33
+ "role": "user",
34
+ "content": f"you are to segment the sentence by adding '1!2@3#' at the end of each sentence. Return only the segmented sentences only return the modified passage and nothing else do not add your responses. here is the passage:{passage}"
35
+ }
36
+ ],
37
+ "temperature": 0.0,
38
+ "max_tokens": 8192
39
+ }
40
+ response = requests.post("https://api.groq.com/openai/v1/chat/completions", json=payload, headers=headers)
41
+ if response.status_code == 200:
42
+ data = response.json()
43
+ try:
44
+ segmented_text = data.get("choices", [{}])[0].get("message", {}).get("content", "")
45
+ sentences = segmented_text.split("1!2@3#")
46
+ return [sentence.strip() for sentence in sentences if sentence.strip()]
47
+ except (IndexError, KeyError):
48
+ raise ValueError("Unexpected response structure from Groq API.")
49
+ else:
50
+ raise ValueError(f"Groq API error: {response.text}")
51
+
52
+ class TextEnhancer:
53
+ def __init__(self):
54
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
55
+ self.paraphrase_tokenizer = AutoTokenizer.from_pretrained("prithivida/parrot_paraphraser_on_T5")
56
+ self.paraphrase_model = T5ForConditionalGeneration.from_pretrained("prithivida/parrot_paraphraser_on_T5").to(self.device)
57
+ self.grammar_pipeline = pipeline(
58
+ "text2text-generation",
59
+ model="Grammarly/coedit-large",
60
+ device=0 if self.device == "cuda" else -1
61
+ )
62
+ self.similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2').to(self.device)
63
+
64
+ def enhance_text(self, text, min_similarity=0.8, max_variations=3):
65
+ sentences = segment_into_sentences_groq(text)
66
+ enhanced_sentences = []
67
+
68
+ for sentence in sentences:
69
+ if not sentence.strip():
70
+ continue
71
+ inputs = self.paraphrase_tokenizer(
72
+ f"paraphrase: {sentence}",
73
+ return_tensors="pt",
74
+ padding=True,
75
+ max_length=150,
76
+ truncation=True
77
+ ).to(self.device)
78
+ outputs = self.paraphrase_model.generate(
79
+ **inputs,
80
+ max_length=len(sentence.split()) + 20,
81
+ num_return_sequences=max_variations,
82
+ num_beams=max_variations,
83
+ temperature=0.7
84
+ )
85
+ paraphrases = [
86
+ self.paraphrase_tokenizer.decode(output, skip_special_tokens=True)
87
+ for output in outputs
88
+ ]
89
+ sentence_embedding = self.similarity_model.encode(sentence)
90
+ paraphrase_embeddings = self.similarity_model.encode(paraphrases)
91
+ similarities = util.cos_sim(sentence_embedding, paraphrase_embeddings)
92
+ valid_paraphrases = [
93
+ para for para, sim in zip(paraphrases, similarities[0])
94
+ if sim >= min_similarity
95
+ ]
96
+ if valid_paraphrases:
97
+ corrected = self.grammar_pipeline(
98
+ valid_paraphrases[0],
99
+ max_length=150,
100
+ num_return_sequences=1
101
+ )[0]["generated_text"]
102
+ enhanced_sentences.append(corrected)
103
+ else:
104
+ enhanced_sentences.append(sentence)
105
+
106
+ enhanced_text = ". ".join(sentence.rstrip(".") for sentence in enhanced_sentences) + "."
107
+ return enhanced_text
108
+
109
+ def create_interface():
110
+ enhancer = TextEnhancer()
111
+
112
+ def process_text(text, similarity_threshold):
113
+ try:
114
+ return enhancer.enhance_text(
115
+ text,
116
+ min_similarity=similarity_threshold / 100
117
+ )
118
+ except Exception as e:
119
+ return f"Error: {str(e)}"
120
+
121
+ interface = gr.Interface(
122
+ fn=process_text,
123
+ inputs=[
124
+ gr.Textbox(label="Input Text", placeholder="Enter text to enhance...", lines=10),
125
+ gr.Slider(minimum=50, maximum=100, value=80, label="Minimum Semantic Similarity (%)")
126
+ ],
127
+ outputs=gr.Textbox(label="Enhanced Text", lines=10),
128
+ title="Text Enhancement System",
129
+ description="Improve text quality while preserving original meaning"
130
+ )
131
+
132
+ return interface
133
+
134
+ if __name__ == "__main__":
135
+ interface = create_interface()
136
+ interface.launch()