SefyanKehail commited on
Commit
d6783ca
·
1 Parent(s): 012f016

epochs test

Browse files
Files changed (1) hide show
  1. app.py +45 -7
app.py CHANGED
@@ -1,14 +1,52 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
 
 
 
 
4
 
5
- zero = torch.Tensor([0]).cuda()
6
- print(zero.device) # <-- 'cpu' 🤔
 
 
 
 
 
 
 
7
 
8
  @spaces.GPU
9
- def greet(n):
10
- print(zero.device) # <-- 'cuda:0' 🤗
11
- return f"Hello {zero + n} Tensor"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14
- demo.launch()
 
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from torchvision import datasets, transforms
7
+ from torch.utils.data import DataLoader
8
 
9
+ class SimpleNet(nn.Module):
10
+ def __init__(self):
11
+ super(SimpleNet, self).__init__()
12
+ self.fc = nn.Linear(784, 10) # Simple model for MNIST
13
+
14
+ def forward(self, x):
15
+ x = x.view(-1, 784) # Flatten the image
16
+ x = self.fc(x)
17
+ return x
18
 
19
  @spaces.GPU
20
+ def train_model(epochs):
21
+ # Load MNIST dataset
22
+ transform = transforms.Compose([transforms.ToTensor()])
23
+ train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
24
+ train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
25
+
26
+ # Model, loss, and optimizer
27
+ model = SimpleNet().cuda()
28
+ criterion = nn.CrossEntropyLoss()
29
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
30
+
31
+ # Training loop
32
+ for epoch in range(epochs):
33
+ model.train()
34
+ running_loss = 0.0
35
+ for data, target in train_loader:
36
+ data, target = data.cuda(), target.cuda()
37
+ optimizer.zero_grad()
38
+ output = model(data)
39
+ loss = criterion(output, target)
40
+ loss.backward()
41
+ optimizer.step()
42
+ running_loss += loss.item()
43
+
44
+ print(f"Epoch {epoch + 1}, Average Loss: {running_loss / len(train_loader)}")
45
+
46
+ # Save the model checkpoint
47
+ torch.save(model.state_dict(), "simple_net.pth")
48
+ return "Training completed and model saved."
49
 
50
+ # Define the Gradio interface
51
+ demo = gr.Interface(fn=train_model, inputs=gr.Slider(1, 5, step=1, default=1, label="Number of Epochs"), outputs="text")
52
+ demo.launch()