Spaces:
Build error
Build error
import torch | |
from config import Config | |
from networks import peft_model | |
# Shared tokenizer and models, all taken from the project-level Config.
tokenizer = Config.tokenizer
# Padding reuses the EOS token.
tokenizer.pad_token = tokenizer.eos_token
# Register the '<question-answer>' marker (presumably the question/answer
# separator used in prompts -- confirm against Config.QUESTION_ANSWER_SEPARATOR_ID).
tokenizer.add_tokens('<question-answer>')
# NOTE(review): this rebinding shadows the `peft_model` imported from `networks`.
peft_model = Config.peft_model
audio_model = Config.audio_model
clip_model = Config.clip_model
projection = Config.projection
def prepare_inputs(text_input=None, image_input=None, audio_input=None):
    """Build the prompt embedding sequence fed to the PEFT language model.

    Audio is transcribed with ``audio_model`` and appended to ``text_input``.
    The combined text is tokenized, wrapped with
    ``Config.IMAGE_SEPARATOR_TOKENS`` and ``Config.QUESTION_ANSWER_SEPARATOR_ID``,
    and embedded with the model's token embedding. An image is encoded with
    ``clip_model`` and mapped through ``projection``. Image embeddings come
    first and text embeddings second, concatenated along dim 1.

    Args:
        text_input: Prompt text, or None.
        image_input: Image accepted by ``Config.processor``, or None.
        audio_input: Audio accepted by ``audio_model.transcribe``, or None.

    Returns:
        The combined embedding tensor, or None when no input produced any
        content (e.g. audio that transcribed to an empty string and no image).
    """
    # Presence is tested with `is None` everywhere. Array inputs (e.g. numpy
    # images/audio) raise on `!= None` / bool(), and an empty transcription
    # must not be confused with "no audio was given".
    processed_audio = None
    if audio_input is not None:
        audio_transcribed = audio_model.transcribe(audio_input)
        processed_audio = ''.join(
            segment['text'] for segment in audio_transcribed['segments']
        ).strip()

    image_embed = None
    if image_input is not None:
        image_processed = Config.processor(images=image_input, return_tensors="pt")
        # Inference only: the projection also runs under no_grad so no
        # autograd graph is built for it.
        with torch.no_grad():
            outputs = clip_model(**image_processed.to(Config.device))
            # Drop position 0 along dim 1 (presumably the CLS token -- confirm).
            last_hidden_state = outputs.last_hidden_state[:, 1:, :]
            image_embed = projection(last_hidden_state.to(Config.device)).to(torch.float16)

    # Typed text first, then the transcription. Empty pieces are skipped so
    # no stray leading/trailing space gets tokenized.
    text_parts = [part for part in (text_input, processed_audio) if part]
    text_audio = ' '.join(text_parts)

    text_embed = None
    if text_audio:
        token_ids = (
            Config.IMAGE_SEPARATOR_TOKENS
            + tokenizer.encode(text_audio)
            + [Config.QUESTION_ANSWER_SEPARATOR_ID]
        )
        with torch.no_grad():
            token_tensor = torch.tensor(token_ids).to(Config.device)
            # unsqueeze(0) adds a leading batch dimension of size 1.
            text_embed = peft_model.model.model.embed_tokens(token_tensor).unsqueeze(0)

    # Combine based on which embeddings actually exist. The input flags are
    # not reliable here, because audio may have transcribed to nothing.
    embeds = [embed for embed in (image_embed, text_embed) if embed is not None]
    if not embeds:
        return None
    if len(embeds) == 1:
        return embeds[0]
    return torch.cat(embeds, dim=1)
def chatbot_response(text_input, image_input, audio_input, max_tokens=60):
    """Answer a multimodal query and return the decoded model output.

    Args:
        text_input: Prompt text. Empty or whitespace-only text counts as absent.
        image_input: Image input, or None.
        audio_input: Audio input, or None.
        max_tokens: Maximum number of tokens to generate (default 60).

    Returns:
        The decoded answer string, or a prompt asking for input when nothing
        usable was provided.
    """
    no_input_message = "Please enter text, upload an image, or record audio."

    # A blank textbox means "no text".
    if isinstance(text_input, str) and not text_input.strip():
        text_input = None

    # `is None` rather than `== None`: array inputs (e.g. numpy images) make
    # `== None` elementwise, and using the result in `and` raises.
    if text_input is None and image_input is None and audio_input is None:
        return no_input_message

    combined_embeds = prepare_inputs(text_input, image_input, audio_input)
    # Guard against inputs that yield no content (e.g. silent audio).
    if combined_embeds is None:
        return no_input_message

    generated_tokens = generate_tokens(combined_embeds, max_tokens=max_tokens)
    return tokenizer.decode(generated_tokens)
def generate_tokens(combined_embeds, max_tokens=100, eos_token_id=50256):
    """Greedily decode token ids that continue a prompt embedding sequence.

    Args:
        combined_embeds: Prompt embeddings passed as ``inputs_embeds`` to
            ``peft_model``. New token embeddings are appended along dim 1.
        max_tokens: Maximum number of tokens to generate.
        eos_token_id: Id that stops generation (not included in the output).
            The default 50256 is presumably the tokenizer's EOS id -- confirm
            against ``tokenizer.eos_token_id``.

    Returns:
        List of generated token ids (ints).
    """
    pred_tokens = []
    combined_embed = combined_embeds
    embed_tokens = peft_model.model.model.embed_tokens
    # Pure inference: without no_grad, every forward pass stores activations
    # for a backward pass that never happens.
    with torch.no_grad():
        for _ in range(max_tokens):
            # NOTE: the full sequence is re-encoded each step (no KV cache).
            logits = peft_model(inputs_embeds=combined_embed).logits[:, -1, :]
            next_token_id = logits.argmax(dim=-1)
            # .item() requires a single element, i.e. batch size 1.
            token = next_token_id.item()
            if token == eos_token_id:
                break
            pred_tokens.append(token)
            next_token_embed = embed_tokens(next_token_id.unsqueeze(0))
            combined_embed = torch.cat((combined_embed, next_token_embed), dim=1)
    return pred_tokens