import easygraph as eg
import pytest
import torch


def test_embedding_reg():
    """Check ``eg.EmbeddingRegularization`` against a hand-computed value.

    Two random ``(10, 3)`` embeddings are passed to the regularizer built
    with ``p=2`` and ``weight_decay=1e-4``. The test expects the result to
    equal ``weight_decay * sum(0.5 * ||emb||_2 ** 2 / num_rows)`` taken over
    both embeddings, where ``num_rows`` is the first dimension of each
    tensor.
    """
    # Turn the former debug print into a real check so a missing export
    # fails with a clear message.
    assert hasattr(
        eg, "EmbeddingRegularization"
    ), "easygraph does not expose EmbeddingRegularization"

    num_rows, emb_dim = 10, 3
    norm_order = 2
    weight_decay = 1e-4

    # A local, seeded generator keeps the inputs reproducible without
    # touching the global torch RNG state used by other tests.
    rng = torch.Generator().manual_seed(0)
    embs = [torch.randn(num_rows, emb_dim, generator=rng) for _ in range(2)]

    emb_reg = eg.EmbeddingRegularization(p=norm_order, weight_decay=weight_decay)
    loss = emb_reg(*embs)

    # Reference value: half the squared L2 norm of each embedding divided
    # by its row count, summed over all embeddings.
    true_loss = sum(
        0.5 * emb.norm(norm_order).pow(norm_order) / num_rows for emb in embs
    )
    assert loss.item() == pytest.approx(weight_decay * true_loss.item())
