# Inference helpers for the multimodal chatbot: build a combined embedding
# sequence from text, image, and audio inputs and generate a reply with the
# PEFT-adapted language model.
import torch
from config import Config
from networks import peft_model


# Reuse the EOS token for padding and register the separator token that marks
# the boundary between the question (prompt) and the answer.
tokenizer = Config.tokenizer
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_tokens('<question-answer>')


def prepare_inputs(peft_model, audio_model, clip_model, projection, text_input=None, image_input=None, audio_input=None):
  """Build a single embedding sequence from any combination of text, image, and audio inputs."""

  text_audio, text_embed, image_embed = None, None, None
  processed_audio, combined_embed = None, None

  if audio_input is not None:
    # Transcribe the audio and join the segment texts into one string.
    audio_transcribed = audio_model.transcribe(audio_input)
    processed_audio = ''

    for audio_segment in audio_transcribed['segments']:
      processed_audio += audio_segment['text']

    processed_audio = processed_audio.strip()

  if image_input is not None:
    # Encode the image with the CLIP vision model, drop the class token, and
    # project the patch embeddings into the language model's embedding space.
    image_processed = Config.processor(images=image_input, return_tensors="pt")

    with torch.no_grad():
      outputs = clip_model(**image_processed)
      last_hidden_state = outputs.last_hidden_state[:, 1:, :]
      image_embed = projection(last_hidden_state.to(Config.device)).to(torch.float16)

  # Merge typed text and transcribed audio into a single prompt string.
  if audio_input is not None and text_input is not None:
    text_audio = f"{text_input} {processed_audio}"
  elif audio_input is not None:
    text_audio = processed_audio
  elif text_input is not None:
    text_audio = text_input

  if text_audio is not None:
    # Wrap the prompt with the image-separator tokens and the question-answer
    # separator, then embed it with the language model's token embedding layer.
    tokenized_text_audio = tokenizer.encode(text_audio)
    tokenized_text_audio = Config.IMAGE_SEPARATOR_TOKENS + tokenized_text_audio + [Config.QUESTION_ANSWER_SEPARATOR_ID]

    with torch.no_grad():
      tokenized_text_audio = torch.tensor(tokenized_text_audio)
      text_embed = peft_model.model.model.embed_tokens(tokenized_text_audio.to(Config.device)).unsqueeze(0)

  # Concatenate image and text embeddings along the sequence dimension when both are present.
  if image_embed is not None and text_embed is not None:
    combined_embed = torch.cat([image_embed, text_embed], dim=1)
  elif text_embed is not None:
    combined_embed = text_embed
  elif image_embed is not None:
    combined_embed = image_embed

  return combined_embed


def chatbot_response(text_input, image_input, audio_input):

    if text_input == '':
        text_input = None

    if text_input is None and image_input is None and audio_input is None:
        return "Please enter text, upload an image, or record audio."

    # audio_model, clip_model, and projection are assumed to be loaded alongside
    # peft_model (e.g. in networks); only peft_model is imported in this file.
    combined_embeds = prepare_inputs(peft_model, audio_model, clip_model, projection,
                                     text_input=text_input, image_input=image_input, audio_input=audio_input)
    generated_tokens = generate_tokens(combined_embeds, max_tokens=60)
    return tokenizer.decode(generated_tokens)



def generate_tokens(combined_embeds, max_tokens=100):
  """Greedily decode up to max_tokens token ids, starting from the combined embeddings."""
  pred_tokens = []

  combined_embed = combined_embeds

  for _ in range(max_tokens):
    with torch.no_grad():
      # Greedy decoding: pick the most likely token at the last position.
      logits = peft_model(inputs_embeds=combined_embed).logits[:, -1, :]
    next_token_id = logits.argmax(dim=-1)

    # 50256 is the GPT-2-style <|endoftext|> id, used here as the stop token.
    if next_token_id.item() == 50256:
      break

    pred_tokens.append(next_token_id.item())

    # Embed the predicted token and append it so the next step conditions on it.
    with torch.no_grad():
      next_token_embed = peft_model.model.model.embed_tokens(next_token_id.unsqueeze(0))
      combined_embed = torch.cat((combined_embed, next_token_embed), dim=1)

  return pred_tokens
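

# Minimal usage sketch, assuming this module is run directly for a quick smoke
# test; the image path is a hypothetical placeholder and PIL is assumed to be
# installed.
if __name__ == "__main__":
    from PIL import Image

    reply = chatbot_response(
        text_input="What is shown in this picture?",
        image_input=Image.open("example.jpg"),  # hypothetical example image
        audio_input=None,
    )
    print(reply)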