Spaces:
Runtime error
Runtime error
import spaces | |
import soundfile as sf | |
import torch | |
from datetime import datetime | |
import random | |
import time | |
from datetime import datetime | |
import whisper | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, VitsModel | |
import torch | |
import numpy as np | |
import os | |
import argparse | |
import gradio as gr | |
from timeit import default_timer as timer | |
import torch | |
import numpy as np | |
import pandas as pd | |
import whisper | |
# whisper_model = whisper.load_model("medium").to("cuda") | |
# tts_model = VitsModel.from_pretrained("facebook/mms-tts-pol") | |
# tts_model.to("cuda") | |
# print("TTS Loaded!") | |
def load_whisper():
    """Load the Whisper "medium" checkpoint on the CPU and return the model."""
    checkpoint, target_device = "medium", 'cpu'
    return whisper.load_model(checkpoint, device=target_device)
def load_tts():
    """Load the ``facebook/mms-tts-pol`` VITS text-to-speech model and its tokenizer.

    The model is not moved to CUDA here, so it stays on whatever device
    ``from_pretrained`` places it on (presumably the CPU).

    Returns:
        A ``(tts_model, tokenizer_tss)`` tuple.
    """
    checkpoint = "facebook/mms-tts-pol"
    model_tts = VitsModel.from_pretrained(checkpoint)
    tts_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model_tts, tts_tokenizer
def save_to_txt(text_to_save):
    """Overwrite ``prompt.txt`` in the current working directory with *text_to_save*.

    The file is written as UTF-8.
    """
    with open('prompt.txt', mode='w', encoding='utf-8') as prompt_file:
        prompt_file.write(text_to_save)
def read_txt():
    """Return the lines of ``prompt.txt`` from the current working directory.

    The file is decoded as UTF-8, matching how ``save_to_txt`` writes it.
    Without an explicit encoding, the platform's locale encoding would be
    used, and Polish characters could come back garbled or raise
    ``UnicodeDecodeError``.

    Returns:
        A list of lines, each keeping its trailing newline (as ``readlines``).
    """
    with open('prompt.txt', encoding='utf-8') as f:
        return f.readlines()
def _load_model_tokenizer():
    """Load the ``tangger/Qwen-7B-Chat`` causal LM and its tokenizer.

    Both use the checkpoint's remote code (``trust_remote_code=True``). The
    model is loaded with ``device_map="auto"`` and ``fp16=True`` and is put in
    eval mode.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = 'tangger/Qwen-7B-Chat'
    chat_tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    chat_model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="auto",
        trust_remote_code=True,
        fp16=True,
    )
    chat_model.eval()
    return chat_model, chat_tokenizer
# Load every model once at import time; the handlers defined below read these
# module-level names directly (whisper_model, tts_model/tokenizer_tss,
# model/tokenizer).
whisper_model = load_whisper()
(tts_model, tokenizer_tss) = load_tts()
(model, tokenizer) = _load_model_tokenizer()
def postprocess(self, y):
    """Convert chat ``(message, response)`` pairs with ``mdtex2html``, in place.

    ``self`` is unused; the signature presumably exists so this can be patched
    onto ``gr.Chatbot`` as a method.

    NOTE(review): ``mdtex2html`` is not imported in this module and nothing
    in this file calls this function, so any non-None entry would raise
    ``NameError`` if it were wired up — confirm before using it.

    Args:
        self: Unused.
        y: List of 2-tuples, or ``None``; entries are replaced in place.

    Returns:
        ``[]`` when *y* is ``None``; otherwise the same list object *y*.
    """
    if y is None:
        return []

    def _convert(text):
        # None entries (e.g. a pending reply) are passed through untouched.
        return None if text is None else mdtex2html.convert(text)

    for position, (message, response) in enumerate(y):
        y[position] = (_convert(message), _convert(response))
    return y
# Characters escaped inside fenced code blocks so the chat widget shows them
# literally instead of interpreting them as HTML/Markdown. "&" is included so
# entity-like text in code (e.g. "&lt;") is not decoded by the browser.
_CODE_BLOCK_ESCAPES = str.maketrans({
    "&": "&amp;",
    "`": r"\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def _parse_text(text):
    """Render a chat message as HTML for the Gradio chatbot.

    Empty lines are dropped, and every remaining line after the first is
    prefixed with ``<br>``. A line containing a fence of three backticks
    toggles a code block: an opening fence becomes
    ``<pre><code class="language-<tag>">`` (``<tag>`` is the text after the
    last backtick), and a closing fence becomes ``<br></code></pre>``. Lines
    inside a code block are HTML-escaped via ``_CODE_BLOCK_ESCAPES``, so code
    renders literally.

    Args:
        text: Raw message text, possibly containing Markdown code fences.

    Returns:
        The HTML string ("" for empty input).
    """
    lines = [line for line in text.split("\n") if line != ""]
    inside_code = False
    for i, line in enumerate(lines):
        if "```" in line:
            inside_code = not inside_code
            if inside_code:
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if inside_code:
                # One pass, so entities inserted for earlier characters are
                # never re-escaped by later replacements.
                line = line.translate(_CODE_BLOCK_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)
def predict(_query, _chatbot, _task_history):
    """Stream a Qwen-7B-Chat reply to a text query into the chatbot.

    NOTE(review): a second ``predict`` (audio input) is defined later in this
    module and rebinds the name. Every reference made after that point,
    including the UI wiring at the bottom of the file, resolves to that later
    definition, not this one.

    Args:
        _query: User message text.
        _chatbot: Displayed chat as ``(user_html, reply_html)`` pairs; a new
            pair is appended and updated in place.
        _task_history: ``(query, response)`` pairs passed to the model as
            history; the finished exchange is appended.

    Yields:
        ``_chatbot`` after each streamed chunk.
    """
    print(f"User: {_parse_text(_query)}")
    _chatbot.append((_parse_text(_query), ""))
    full_response = ""
    for response in model.chat_stream(
        tokenizer,
        _query,
        history=_task_history,
        system="Jesteś asystentem AI. Odpowiadaj zawsze w języku polskim",
    ):
        # Each streamed ``response`` replaces the previous one in the display.
        _chatbot[-1] = (_parse_text(_query), _parse_text(response))
        yield _chatbot
        # Keep the raw text: the history is fed back to the model, so it must
        # not contain the HTML produced by _parse_text.
        full_response = response
    print(f"History: {_task_history}")
    _task_history.append((_query, full_response))
    print(f"Qwen-7B-Chat: {_parse_text(full_response)}")
def read_text(text):
    """Synthesize *text* with the MMS TTS model and save it as ``temp_file.wav``.

    The file is overwritten on every call.

    Args:
        text: Text to speak.

    Returns:
        The path of the written WAV file, ``'temp_file.wav'``.
    """
    print("___Tekst do przeczytania!")
    # Send inputs to whatever device the TTS model lives on. Hard-coding
    # "cuda" fails when the model is on the CPU (device mismatch, or no CUDA).
    inputs = tokenizer_tss(text, return_tensors="pt").to(tts_model.device)
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform.squeeze().cpu().numpy()
    sf.write('temp_file.wav', waveform, tts_model.config.sampling_rate)
    return 'temp_file.wav'
def update_audio(text):
    """Return the path of the last synthesized reply for the audio player.

    *text* (the UI passes the chatbot value here) is ignored.
    """
    wav_path = 'temp_file.wav'
    return wav_path
def translate(audio):
    """Transcribe the recording *audio* with Whisper, forcing Polish.

    Despite the name, no ``task="translate"`` is requested; the call passes
    only ``language="pl"``.

    Returns:
        The ``"text"`` field of Whisper's transcription result.
    """
    print("__Wysyłam nagranie do whisper!")
    result = whisper_model.transcribe(audio, language="pl")
    transcript = result["text"]
    return transcript
def predict(audio, _chatbot, _task_history):
    """Answer a spoken (or typed) question with Qwen-7B-Chat and speak the reply.

    Args:
        audio: Either a path to a recording, which is transcribed with Whisper
            via ``translate``, or a plain text query, which is used as-is.
            Any string that is not an existing file path counts as text,
            e.g. input from the text box or from ``regenerate``. ``None`` or
            an empty query leaves the chat unchanged.
        _chatbot: Displayed chat as ``(user_html, reply_html)`` pairs; a new
            pair is appended and updated in place.
        _task_history: ``(query, response)`` pairs passed to the model as
            history; the finished exchange is appended.

    Yields:
        ``_chatbot`` after each streamed chunk.

    Side effects:
        Writes the spoken reply to ``temp_file.wav`` via ``read_text``.
    """
    if audio is None:
        # Nothing was recorded.
        yield _chatbot
        return
    if isinstance(audio, str) and not os.path.isfile(audio):
        # Plain text query: the text button and ``regenerate`` call this
        # function with text, which must not be handed to Whisper as a path.
        _query = audio
    else:
        _query = translate(audio)
    if not _query.strip():
        yield _chatbot
        return

    print(f"____User: {_parse_text(_query)}")
    _chatbot.append((_parse_text(_query), ""))
    full_response = ""
    for response in model.chat_stream(
        tokenizer,
        _query,
        history=_task_history,
        system="Jesteś asystentem AI. Odpowiadaj zawsze w języku polskim. Odpowiadaj krótko.",
    ):
        # Each streamed ``response`` replaces the previous one in the display.
        _chatbot[-1] = (_parse_text(_query), _parse_text(response))
        yield _chatbot
        # Keep the raw text for history and TTS, not the HTML rendering.
        full_response = response
    print(f"____History: {_task_history}")
    _task_history.append((_query, full_response))
    print(f"__Qwen-7B-Chat: {_parse_text(full_response)}")
    print("____full_response", full_response)
    if full_response.strip():
        # Speak the plain reply so the TTS model does not read tags/entities.
        read_text(full_response)
def regenerate(_chatbot, _task_history):
    """Drop the last exchange and ask ``predict`` to answer its query again.

    If there is no history, yields ``_chatbot`` once unchanged. Otherwise the
    last ``_task_history`` entry and the last ``_chatbot`` row are removed,
    and the updates from ``predict`` are re-yielded for the removed query.
    """
    if _task_history:
        last_query = _task_history.pop()[0]
        del _chatbot[-1]
        yield from predict(last_query, _chatbot, _task_history)
    else:
        yield _chatbot
# Gradio 4 replaced ``gr.Audio(source=...)`` with ``gr.Audio(sources=[...])``;
# passing the old keyword to Gradio >= 4 raises at startup. Pick the keyword
# that matches the installed version.
_GRADIO_MAJOR = int(gr.__version__.split(".")[0])
_MICROPHONE_KWARGS = (
    {"sources": ["microphone"]} if _GRADIO_MAJOR >= 4 else {"source": "microphone"}
)

with gr.Blocks() as chat_demo:
    chatbot = gr.Chatbot(label='Llama Voice Chatbot', elem_classes="control-height")
    query = gr.Textbox(lines=2, label='Input')
    task_history = gr.State([])
    # Preload the last generated reply only if it exists. NOTE(review): a
    # missing initial file may make component construction fail, depending
    # on the Gradio version.
    audio_output = gr.Audio(
        'temp_file.wav' if os.path.exists('temp_file.wav') else None,
        label="Generated Audio (wav)",
        type='filepath',
        autoplay=False,
    )
    with gr.Row():
        submit_btn = gr.Button("🚀 Wyślij tekst")
    with gr.Row():
        audio_upload = gr.Audio(type="filepath", show_label=False, **_MICROPHONE_KWARGS)
        submit_audio_btn = gr.Button("🎙️ Wyślij audio")
    # Both buttons go through ``predict``; after an audio question, the player
    # is refreshed with the newly synthesized reply.
    submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True)
    submit_audio_btn.click(
        predict, [audio_upload, chatbot, task_history], [chatbot], show_progress=True
    ).then(update_audio, chatbot, audio_output)

chat_demo.queue().launch()