Metadata-Version: 2.1
Name: wassersteinwormhole
Version: 0.2.0
Summary: Transformer based embeddings for Wasserstein Distances
License: MIT
Author: Anon
Requires-Python: >=3.9,<4.0
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Dist: clu (==0.0.10)
Requires-Dist: flax (==0.7.5)
Requires-Dist: ott-jax (==0.4.4)
Requires-Dist: tqdm (>=4.66.1,<5.0.0)
Description-Content-Type: text/markdown

WassersteinWormhole for Python3
======================

Embedding point clouds while preserving their Wasserstein distances, using the Wormhole.

This implementation is written in Python 3 and relies on Flax, JAX, and OTT-JAX.
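
For intuition about the quantity the Wormhole aims to preserve, here is a small sketch that computes the exact 2-Wasserstein distance between two tiny point clouds with uniform weights by brute force. It uses only the Python standard library. The function name `wasserstein2_exact` and the toy clouds are hypothetical. This is not how the package computes distances, since it relies on OTT-JAX for its optimal-transport computations.

```python
import itertools
import math


def wasserstein2_exact(x, y):
    """Exact 2-Wasserstein distance between two equal-size point clouds with uniform weights.

    Hypothetical helper for illustration only. With uniform weights and equal
    sizes, an optimal transport plan can be taken to be a permutation
    (Birkhoff-von Neumann), so all n! matchings are tried. Only practical for tiny n.
    """
    if len(x) != len(y):
        raise ValueError("this sketch assumes equal-size point clouds")
    n = len(x)
    best = math.inf
    for perm in itertools.permutations(range(n)):
        # Average squared Euclidean cost of matching x[i] to y[perm[i]]
        cost = sum(math.dist(x[i], y[j]) ** 2 for i, j in enumerate(perm)) / n
        best = min(best, cost)
    return math.sqrt(best)


cloud_a = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
cloud_b = [(px + 2.0, py) for px, py in cloud_a]  # cloud_a translated by (2, 0)

# For a pure translation, W2 equals the length of the shift vector.
print(wasserstein2_exact(cloud_a, cloud_b))  # prints 2.0
```

Because the cost is exponential in the number of points, this version is only useful for checking intuition on toy inputs. Real point clouds need a dedicated solver.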


To install JAX with CUDA 11 GPU support, run:

    pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then install WassersteinWormhole along with its remaining requirements:

    pip install wassersteinwormhole

Running the Wormhole on your own set of point clouds is then as simple as:
    
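    import numpy as np
    from wassersteinwormhole import Wormhole
    # Toy input (assumed format, replace with your own): a list of point clouds, each of shape (num_points, num_features)
    point_clouds = [np.random.normal(size=(np.random.randint(32, 64), 2)) for _ in range(100)]
    WormholeModel = Wormhole(point_clouds=point_clouds)
    WormholeModel.train()
    # Encode the training point clouds stored on the model, together with their masks
    Embeddings = WormholeModel.encode(WormholeModel.point_clouds, WormholeModel.masks)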
    
For more details, see the tutorial at https://github.com/dpeerlab/WassersteinWormhole.
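
One way to check how well learned embeddings preserve Wasserstein distances is to correlate the pairwise Euclidean distances between embeddings with a reference matrix of pairwise Wasserstein distances computed separately (for example with OTT-JAX). The following is a minimal stdlib-only sketch. The helpers `pearson` and `distance_preservation` are hypothetical and not part of the package API.

```python
import itertools
import math


def pearson(a, b):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def distance_preservation(embeddings, ot_dist):
    """Correlate pairwise Euclidean embedding distances with reference OT distances.

    embeddings: list of embedding vectors, one per point cloud.
    ot_dist: square matrix (list of lists) of reference Wasserstein distances.
    Only pairs i < j are used, since both distance matrices are symmetric.
    """
    emb_d, ref_d = [], []
    for i, j in itertools.combinations(range(len(embeddings)), 2):
        emb_d.append(math.dist(embeddings[i], embeddings[j]))
        ref_d.append(ot_dist[i][j])
    return pearson(emb_d, ref_d)


# Toy check: the embedding distances (3, 4, 5) match the reference exactly.
toy_embeddings = [(0.0, 0.0), (3.0, 0.0), (0.0, 4.0)]
toy_ot = [[0.0, 3.0, 4.0], [3.0, 0.0, 5.0], [4.0, 5.0, 0.0]]
print(distance_preservation(toy_embeddings, toy_ot))  # prints 1.0
```

To apply it to model output, convert the embeddings to plain Python lists first (e.g. `numpy.asarray(Embeddings).tolist()`). The exact OT quantity the model is trained to match may depend on its configuration, so build the reference matrix with the same quantity. A correlation near 1 indicates that the ordering and relative scale of distances are well preserved.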
    
