Bruno commited on
Commit
7573f5f
1 Parent(s): b43fcf2

Create apply.py

Browse files
Files changed (1) hide show
  1. apply.py +77 -0
apply.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from util import UIDataset, Vocabulary
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ from model import *
8
+ from torchvision import transforms
9
+ from PIL import Image
10
+
11
+ # Carrega o modelo treinado
12
+ net = Pix2Code()
13
+ net.load_state_dict(torch.load('./pix2code.weights'))
14
+ net.cuda().eval()
15
+
16
+ # Carrega o vocabulário
17
+ vocab = Vocabulary('voc.pkl')
18
+
19
+ # Define uma transformação para redimensionar e normalizar as imagens
20
+ transform = transforms.Compose([
21
+ transforms.Resize((256, 256)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
24
+ ])
25
+
26
+ # Função que receberá a imagem e retornará o código GUI gerado
27
+ def generate_gui(image):
28
+ # Aplica a transformação na imagem
29
+ image = transform(image).unsqueeze(0).cuda()
30
+
31
+ # Cria um contexto inicial
32
+ context = torch.tensor([vocab.to_vec(' '), vocab.to_vec('<START>')]).unsqueeze(0).float().cuda()
33
+
34
+ # Inicializa uma lista para armazenar o código gerado
35
+ code = []
36
+
37
+ # Gera o código iterativamente até encontrar o token <END>
38
+ for i in range(200):
39
+ # Passa a imagem e o contexto para a rede neural e obtém o índice do token com maior probabilidade
40
+ index = torch.argmax(net(image, context), 2).squeeze()[-1:].squeeze()
41
+
42
+ # Converte o índice para o token correspondente
43
+ token = vocab.to_vocab(int(index))
44
+
45
+ # Se encontrar o token <END>, interrompe a geração do código
46
+ if token == '<END>':
47
+ break
48
+
49
+ # Adiciona o token à lista de código gerado
50
+ code.append(token)
51
+
52
+ # Atualiza o contexto com o token gerado
53
+ context = torch.cat([context, torch.tensor([vocab.to_vec(token)]).unsqueeze(0).float().cuda()], dim=1)
54
+
55
+ # Retorna o código gerado como uma string
56
+ return ''.join(code)
57
+
58
+ import gradio as gr
59
+
60
+ # Define o componente de entrada
61
+ image_input = gr.inputs.Image()
62
+
63
+ # Define o componente de saída
64
+ text_output = gr.outputs.Textbox()
65
+
66
+ # Cria a interface Gradio
67
+ iface = gr.Interface(
68
+ fn=generate_gui,
69
+ inputs=image_input,
70
+ outputs=text_output,
71
+ title='Pix2Code',
72
+ description='Gerador de código GUI a partir de imagens',
73
+ theme='default'
74
+ )
75
+
76
+ # Executa a interface
77
+ iface.launch()