File size: 2,552 Bytes
f483191
 
 
 
 
 
 
 
e4aeeee
48bf45d
b791812
5e32016
03b857c
1f1b415
f74f4fd
 
03b857c
 
f483191
55ea56c
 
f483191
 
 
 
 
555d0f1
f483191
 
 
 
 
 
55ea56c
58c29f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322ff78
 
 
 
 
 
 
55ea56c
f483191
 
55ea56c
f483191
 
 
1f1b415
3d68800
 
f483191
55ea56c
f483191
 
 
33c6f10
b791812
f483191
 
 
b2c635f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 21 22:17:43 2023

@author: Loges
"""

import base64
import io
import time

import streamlit as st
import sentencepiece
from gtts import gTTS
from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration

# set_page_config runs before anything else touches Streamlit: older Streamlit
# versions require it to be the first st.* command on the page.
st.set_page_config(page_title='Sample Chatbot', layout='wide')

_MODEL_ID = "Logeswaransr/T5_MineAI_Prototype"


@st.cache_resource(show_spinner=False)
def _load_model_and_pipeline():
    """Build the T5 model, tokenizer and text2text pipeline once per process.

    Streamlit re-executes this script top to bottom on every user
    interaction; without caching, the model would be reloaded on each rerun.

    Returns:
        A ``(model, tokenizer, pipe)`` tuple, with the model placed on CPU.
    """
    t5_model = T5ForConditionalGeneration.from_pretrained(_MODEL_ID).to("cpu")
    t5_tokenizer = T5Tokenizer.from_pretrained(_MODEL_ID)
    t5_pipe = pipeline('text2text-generation', model=t5_model, tokenizer=t5_tokenizer)
    return t5_model, t5_tokenizer, t5_pipe


model, tokenizer, pipe = _load_model_and_pipeline()

# Canned assistant messages shown when a session starts; the second one is
# also spoken aloud by the greeting logic.
greetings=["Hello! My name is MineAI, A specially trained LLM here to assist you on your Mining Related Queries.","How may I help you?"]

# Chat history lives in session state so it survives Streamlit reruns.
# Each entry looks like {'role': 'user' | 'assistant', 'content': '<text>'}.
if 'messages' not in st.session_state:
    st.session_state['messages'] = []

st.subheader("Mine AI")

# Replay every stored turn so the conversation stays visible after a rerun.
for past_turn in st.session_state.messages:
    speaker = past_turn['role']
    with st.chat_message(speaker):
        text = past_turn['content']
        st.markdown(text)

# First run of this session (empty history): show the canned greetings,
# speak the second one aloud, and store them so later reruns replay them.
if not st.session_state.messages:
    for position, gr in enumerate(greetings):
        with st.chat_message("assistant"):
            st.markdown(gr)

        # Select by position rather than by text, so duplicate greeting
        # strings cannot trigger the audio more than once.
        if position == 1:
            # Synthesize into memory instead of a fixed file on disk: a shared
            # 'greeting_audio.mp3' could be overwritten by a concurrent session
            # between save() and read().
            greeting_audio_buffer = io.BytesIO()
            gTTS(gr).write_to_fp(greeting_audio_buffer)
            greeting_audio_base64 = base64.b64encode(
                greeting_audio_buffer.getvalue()).decode('utf-8')
            # NOTE(review): browsers may block autoplay until the user has
            # interacted with the page.
            greeting_audio_tag = (
                '<audio autoplay="true" '
                f'src="data:audio/mp3;base64,{greeting_audio_base64}"></audio>'
            )
            st.markdown(greeting_audio_tag, unsafe_allow_html=True)

        st.session_state.messages.append({
            'role': 'assistant',
            'content': gr})

@st.cache_data(show_spinner=False)
def _tts_audio_tag(text):
    """Return an autoplaying HTML ``<audio>`` tag that speaks *text*.

    The MP3 produced by gTTS is embedded as a base64 data URI.

    Results are cached per ``text``, so the gTTS request happens once per
    phrase rather than on every Streamlit rerun. The audio is synthesized
    into memory, so concurrent sessions never share a file on disk.
    """
    buffer = io.BytesIO()
    gTTS(text).write_to_fp(buffer)
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f'<audio autoplay="true" src="data:audio/mp3;base64,{encoded}"></audio>'


# Spoken acknowledgement rendered after each model reply.
audio_tag = _tts_audio_tag("Here is your answer")

# Handle a newly submitted query: echo it, run the model, then display and
# voice the reply. Both turns are appended to the session history.
prompt = st.chat_input("Enter your query")
if prompt:
    user_entry = {'role': 'user', 'content': prompt}
    with st.chat_message(user_entry['role']):
        st.markdown(prompt)
    st.session_state.messages.append(user_entry)

    # Only the first result's 'generated_text' is used as the answer.
    response = pipe(prompt)[0]['generated_text']

    with st.chat_message("assistant"):
        st.markdown(response)

    # Raw HTML <audio> element carrying the pre-built spoken acknowledgement.
    st.markdown(audio_tag, unsafe_allow_html=True)

    assistant_entry = {'role': 'assistant', 'content': response}
    st.session_state.messages.append(assistant_entry)