I get an error when I call this function, train_bert:
train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate)
The error is:
AttributeError: 'tuple' object has no attribute 'to'
The implementation of the training function is as follows:
import copy

import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

# Note: `device`, `bert_model`, and `evaluate_loss` are defined elsewhere in my script.
def train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate):
    best_loss = np.inf
    best_ep = 1
    nb_iterations = len(train_loader)
    print_every = nb_iterations // 5  # print the running loss 5 times per epoch
    iters = []
    train_losses = []
    val_losses = []
    scaler = GradScaler()  # gradient scaler for mixed-precision (AMP) training

    for ep in range(epochs):
        net.train()
        running_loss = 0.0
        for it, (seq, attn_masks, token_type_ids, labels) in enumerate(tqdm(train_loader)):
            # Move the batch tensors to the GPU
            seq, attn_masks, token_type_ids, labels = \
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device)

            with autocast():  # mixed-precision forward pass
                logits = net(seq, attn_masks, token_type_ids)
                loss = criterion(logits.squeeze(-1), labels.float())
                loss = loss / iters_to_accumulate  # normalize for gradient accumulation

            scaler.scale(loss).backward()

            if (it + 1) % iters_to_accumulate == 0:
                # Optimizer step only every `iters_to_accumulate` batches
                scaler.step(opti)
                scaler.update()
                lr_scheduler.step()
                opti.zero_grad()

            running_loss += loss.item()

            if (it + 1) % print_every == 0:
                print()
                print("Iteration {}/{} of epoch {} complete. Loss : {} "
                      .format(it + 1, nb_iterations, ep + 1, running_loss / print_every))
                running_loss = 0.0

        val_loss = evaluate_loss(net, device, criterion, val_loader)
        print()
        print("Epoch {} complete! Validation Loss : {}".format(ep + 1, val_loss))

        if val_loss < best_loss:
            print("Best validation loss improved from {} to {}".format(best_loss, val_loss))
            print()
            net_copy = copy.deepcopy(net)  # keep a copy of the best model so far
            best_loss = val_loss
            best_ep = ep + 1

    # Save the best model found across all epochs
    path_to_model = 'models/{}_lr_{}_val_loss_{}_ep_{}.pt'.format(bert_model, lr, round(best_loss, 5), best_ep)
    torch.save(net_copy.state_dict(), path_to_model)
    print("The model has been saved in {}".format(path_to_model))

    del loss
    torch.cuda.empty_cache()
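From the traceback, the .to() call that fails is one of the four in the loop above, which would mean that at least one of the values unpacked from the batch is a tuple instead of a tensor. If it helps, here is a small diagnostic sketch I can run to see what the DataLoader actually yields (nothing here is specific to my model; train_loader is the same loader passed to train_bert):

# Diagnostic sketch: inspect the first batch from the DataLoader.
batch = next(iter(train_loader))
print(type(batch), len(batch))  # the loop expects 4 items that unpack into tensors
for i, item in enumerate(batch):
    # Each item should be a torch.Tensor; a tuple here is exactly what raises
    # "AttributeError: 'tuple' object has no attribute 'to'"
    print(i, type(item))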
Please help me; I have been struggling with this problem for weeks :(
What I have tried:
I have tried many solutions; one of them was to encode the labels, but it didn't work.
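For reference, my understanding is that the unpacking in the training loop assumes the Dataset's __getitem__ returns four tensors, which the default collate_fn then stacks into batched tensors. The class and field names below are hypothetical, just to illustrate the return types the loop expects:

import torch
from torch.utils.data import Dataset

# Hypothetical dataset, only to show the return types the training loop expects.
class PairDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings  # e.g. the dict produced by a Hugging Face tokenizer
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Each value must be a torch.Tensor so that .to(device) works after
        # the default collate_fn stacks them into a batch.
        return (torch.tensor(self.encodings['input_ids'][idx]),
                torch.tensor(self.encodings['attention_mask'][idx]),
                torch.tensor(self.encodings['token_type_ids'][idx]),
                torch.tensor(self.labels[idx]))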