hibalaz committed on
Commit
59654d2
•
1 Parent(s): 67de693

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sentence_transformers import SentenceTransformer, util
3
+ from transformers import pipeline, GPT2Tokenizer
4
+ import os
5
+
6
# Paths and model identifiers used throughout the app.
filename = "output_country_details.txt"  # Adjust the filename as needed
retrieval_model_name = 'output/sentence-transformer-finetuned/'
gpt2_model_name = "gpt2"  # GPT-2 model
# Reuse the configured model name so tokenizer and generator cannot drift apart.
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

# Load models.
# Bind both names up front: if loading fails the names still exist (as None),
# so later code fails with a clear AttributeError inside its own try/except
# instead of a NameError for an undefined global.
retrieval_model = None
gpt_model = None
try:
    retrieval_model = SentenceTransformer(retrieval_model_name)
    gpt_model = pipeline("text-generation", model=gpt2_model_name)
    print("Models loaded successfully.")
except Exception as e:
    # Best-effort startup: report the failure and let the UI come up anyway.
    print(f"Failed to load models: {e}")
19
+
20
+ # Load and preprocess text from the country details file
21
def load_and_preprocess_text(filename):
    """Read ``filename`` and return its non-blank lines with whitespace stripped.

    Args:
        filename: Path of a UTF-8 text file, one segment per line.

    Returns:
        A list of stripped, non-empty lines in file order, or an empty list
        if the file cannot be opened or read (the error is printed).
    """
    try:
        with open(filename, 'r', encoding='utf-8') as handle:
            # Strip each line once, then drop the ones that end up empty.
            stripped_lines = (raw_line.strip() for raw_line in handle)
            non_blank = [text for text in stripped_lines if text]
        print("Text loaded and preprocessed successfully.")
        return non_blank
    except Exception as e:
        print(f"Failed to load or preprocess text: {e}")
        return []
30
+
31
+ segments = load_and_preprocess_text(filename)  # read once at module load; an empty list if the file could not be read
32
+
33
def find_relevant_segment(user_query, segments):
    """Return the entry of ``segments`` most similar to ``user_query``.

    The query and all segments are embedded with the module-level
    ``retrieval_model`` and compared by cosine similarity; the segment with
    the highest score is returned.

    NOTE(review): segment embeddings are recomputed on every call.

    Args:
        user_query: The question text to match.
        segments: Sequence of candidate text segments.

    Returns:
        The best-matching segment, or ``""`` when ``segments`` is empty or
        any step of the lookup raises (the error is printed).
    """
    # Nothing to search: answer immediately rather than running the model on
    # an empty corpus and relying on the exception handler below.
    if not segments:
        print("No segments available to search.")
        return ""
    try:
        query_embedding = retrieval_model.encode(user_query)
        segment_embeddings = retrieval_model.encode(segments)
        similarities = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
        # Convert the argmax result to a plain int before indexing the list.
        best_idx = int(similarities.argmax())
        print("Relevant segment found:", segments[best_idx])
        return segments[best_idx]
    except Exception as e:
        print(f"Error finding relevant segment: {e}")
        return ""
44
+
45
def generate_response(user_query, relevant_segment):
    """Generate the bot's reply by letting the text-generation model continue a prompt.

    The prompt is a fixed lead-in sentence followed by ``relevant_segment``.
    The generated text is then passed through ``clean_up_response``.

    NOTE(review): ``user_query`` is accepted but never placed in the prompt.
    NOTE(review): ``temperature`` is passed without ``do_sample``; confirm
    that sampling is actually enabled for the pipeline.

    Args:
        user_query: The user's question (currently unused).
        relevant_segment: Retrieved text appended to the prompt.

    Returns:
        The cleaned generated text, or ``""`` if anything raises (the error
        is printed).
    """
    try:
        prompt = f"Thank you for your question! this is an additional fact about your topic: {relevant_segment}"

        # Allow 50 tokens on top of however many the prompt itself occupies.
        prompt_token_count = len(tokenizer(prompt)['input_ids'])
        max_tokens = prompt_token_count + 50
        generations = gpt_model(prompt, max_length=max_tokens, temperature=0.25)
        response = generations[0]['generated_text']

        # Clean and format before handing the text back to the UI.
        return clean_up_response(response, relevant_segment)
    except Exception as e:
        print(f"Error generating response: {e}")
        return ""
60
+
61
def clean_up_response(response, segments):
    """Drop empty, already-known and repeated sentences from ``response``.

    The text is split on ``'.'``. A stripped fragment is kept only if it is
    non-empty, is not contained in ``segments`` and has not been kept already.
    ``segments`` is used with the ``in`` operator, so a string gives a
    substring test and a list gives a membership test.

    Args:
        response: Raw generated text.
        segments: Text (or collection of texts) whose sentences are removed.

    Returns:
        The kept fragments joined with ``'. '``, with a trailing ``'.'``
        added unless the result is empty or already ends in ``.``, ``!`` or ``?``.
    """
    kept = []
    for fragment in response.split('.'):
        candidate = fragment.strip()
        if not candidate:
            continue
        if candidate in segments or candidate in kept:
            continue
        kept.append(candidate)

    result = '. '.join(kept).strip()

    # Terminate the text if it does not already end with sentence punctuation.
    if result and result[-1] not in ".!?":
        result += "."

    return result
79
+
80
# Greeting rendered at the top of the page and returned for an empty query.
# Markdown headings are used to get the larger fonts.
welcome_message = (
    "\n"
    "# Welcome to VISABOT!\n"
    "\n"
    "## Your AI-driven visa assistant for all travel-related queries.\n"
    "\n"
)

# Markdown list of the subject areas the bot advertises.
topics = (
    "\n"
    "### Feel Free to ask me anything from the topics below!\n"
    "- Visa issuance\n"
    "- Documents needed\n"
    "- Application process\n"
    "- Processing time\n"
    "- Recommended Vaccines\n"
    "- Health Risks\n"
    "- Healthcare Facilities\n"
    "- Currency Information\n"
    "- Embassy Information\n"
    "- Allowed stay\n"
)
102
+
103
# Markdown list of supported countries, each prefixed with its flag emoji.
# The flags were mis-decoded into mojibake; they are restored here from the
# country names (each flag is the pair of regional-indicator symbols for the
# country's ISO 3166-1 alpha-2 code).
countries = """
### Our chatbot can currently answer questions for these countries!
- 🇨🇳 China
- 🇫🇷 France
- 🇬🇹 Guatemala
- 🇱🇧 Lebanon
- 🇲🇽 Mexico
- 🇵🇭 Philippines
- 🇷🇸 Serbia
- 🇸🇱 Sierra Leone
- 🇿🇦 South Africa
- 🇻🇳 Vietnam
"""
116
+
117
+ # Define the Gradio app interface
118
def query_model(question):
    """Answer one question from the UI.

    Args:
        question: Text entered by the user.

    Returns:
        ``welcome_message`` when no usable input was given (``None``, empty
        or whitespace-only text); otherwise the generated response for the
        segment that best matches ``question``.
    """
    # Treat None and whitespace-only input like an empty box, so a blank
    # submission shows the greeting instead of running retrieval/generation.
    if not question or not question.strip():
        return welcome_message
    relevant_segment = find_relevant_segment(question, segments)
    return generate_response(question, relevant_segment)
124
+
125
+ # Create Gradio Blocks interface for custom layout
126
+ with gr.Blocks() as demo:  # `demo` is the app object that is launched at the end of the file
127
+ gr.Markdown(welcome_message) # Display the welcome message with large fonts
128
+ with gr.Row():  # row holding the topics and countries lists
129
+ with gr.Column():
130
+ gr.Markdown(topics) # Display the topics on the left
131
+ with gr.Column():
132
+ gr.Markdown(countries) # Display the countries with flag emojis on the right
133
+ with gr.Row():  # row holding the image
134
+ img = gr.Image(os.path.join(os.getcwd(), "final.png"), width=500) # Adjust width as needed; the path is built from the current working directory
135
+ with gr.Row():  # row holding the question/answer widgets
136
+ with gr.Column():
137
+ question = gr.Textbox(label="Your question", placeholder="What do you want to ask about?")  # user input
138
+ answer = gr.Textbox(label="VisaBot Response", placeholder="VisaBot will respond here...", interactive=False, lines=10)  # read-only output box
139
+ submit_button = gr.Button("Submit")
140
+ submit_button.click(fn=query_model, inputs=question, outputs=answer)  # on click, call query_model with `question` and write the result to `answer`
141
+
142
+ # Launch the app
143
+ demo.launch()  # called at module level, with no __main__ guard