Triton kernels for linear RNN (currently recurrentgemma)

TushaarGVS/linear-rnn

Linear RNN (Triton)
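As background, here is a minimal statement of the recurrence a diagonal linear RNN computes. It assumes the standard form suggested by the file name sequential_scan_diag_a.py (a per-channel transition a_t, elementwise product ⊙, zero initial state). The exact formulation used by the kernels is not stated here. RecurrentGemma's RG-LRU layer has this shape, with a data-dependent a_t and a gated, rescaled input in place of x_t.

```latex
% Diagonal linear recurrence (assumed form), computed independently per channel
h_t = a_t \odot h_{t-1} + x_t, \qquad t = 1, \dots, T, \qquad h_0 = 0

% Unrolled: each input x_s is decayed by the product of the later transitions
h_t = \sum_{s=1}^{t} \Big( \prod_{r=s+1}^{t} a_r \Big) \odot x_s
```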

Installation

The repository uses Poetry to manage dependencies. First, create a conda environment and clone the repository:

conda create -n linear-rnn-env python=3.10
conda activate linear-rnn-env
# conda install -c "nvidia/label/cuda-12.4.0" cuda-toolkit

cd $HOME; git clone https://github.com/TushaarGVS/linear-rnn.git
cd $HOME/linear-rnn

Then, to install the package (and its dependencies) in editable mode, run:

pip install -e .
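After installing, a quick import check can confirm the environment is usable. This is a sketch under stated assumptions. The import name linear_rnn is inferred from the repository layout (linear_rnn/triton/...). torch and triton are assumed runtime dependencies. Adjust REQUIRED_MODULES to whatever the project's pyproject.toml actually lists.

```python
"""Post-install sanity check (a sketch; the module names are assumptions)."""
import importlib.util
import sys

# Assumed import names: `linear_rnn` is inferred from the repository layout,
# `torch` and `triton` are assumed runtime dependencies of the Triton kernels.
REQUIRED_MODULES = ("linear_rnn", "torch", "triton")


def missing_modules(names=REQUIRED_MODULES):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def python_ok(minimum=(3, 10)):
    """True if the running interpreter is at least `minimum` (the conda env uses 3.10)."""
    return sys.version_info[:2] >= minimum


if __name__ == "__main__":
    print(f"Python {sys.version.split()[0]} (>= 3.10: {python_ok()})")
    missing = missing_modules()
    print("missing modules:", ", ".join(missing) if missing else "none")
```

Run it from inside the activated conda environment. A non-empty "missing modules" line usually means the editable install ran under a different interpreter.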

Profiling (and running)

Sample tests are included in the source files; refer to them for examples of how to use the provided modules. To profile a script with Nsight Compute, run the following (the full list of options is documented in the Nsight Compute CLI reference):

ncu_path -f -o ~/profile_log python3 linear_rnn/triton/sequential_scan_diag_a.py
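When checking a kernel's numerical output outside the profiler, a plain NumPy reference of the recurrence can help. The sketch below makes several assumptions that should be checked against the sample tests in that file. It assumes the kernel in sequential_scan_diag_a.py computes h_t = a_t * h_{t-1} + x_t with a per-channel (diagonal) a_t. It also assumes a (batch, seq_len, dim) layout. Both the layout and the function name sequential_scan_diag_a_ref are hypothetical.

```python
import numpy as np


def sequential_scan_diag_a_ref(a, x, h0=None):
    """Hypothetical reference for the diagonal recurrence h_t = a_t * h_{t-1} + x_t.

    a, x: arrays of shape (batch, seq_len, dim) (assumed layout).
    h0:   optional initial state of shape (batch, dim); defaults to zeros.
    Returns every hidden state, shape (batch, seq_len, dim), in float64.
    """
    a = np.asarray(a, dtype=np.float64)
    x = np.asarray(x, dtype=np.float64)
    if a.shape != x.shape:
        raise ValueError(f"shape mismatch: a{a.shape} vs x{x.shape}")
    batch, seq_len, dim = x.shape
    h = np.zeros((batch, dim)) if h0 is None else np.asarray(h0, dtype=np.float64)
    out = np.empty_like(x)
    for t in range(seq_len):       # strictly sequential over the time axis
        h = a[:, t] * h + x[:, t]  # elementwise: each channel has its own decay
        out[:, t] = h
    return out
```

To compare, move the kernel's output to NumPy (for a PyTorch tensor, out.cpu().numpy()). Then check it with np.testing.assert_allclose, using a tolerance suited to the kernel's precision.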

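On Linux, the default ncu_path is /usr/local/cuda-<version>/nsight-compute-<version>/ncu; for other platforms, refer to the Nsight Compute installation documentation. The -o option writes a binary report (ncu appends the .ncu-rep extension if it is missing), which can be opened in the Nsight Compute GUI or printed in the terminal with ncu --import <report>.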
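If you are unsure where ncu lives, a small helper can look it up before profiling. This is a best-effort sketch. It checks PATH first and then globs the default Linux layout given above. The function name find_ncu and the "pick the lexicographically last match" rule are illustrative choices, not part of the repository.

```python
import glob
import os
import shutil


def find_ncu(cuda_root="/usr/local"):
    """Hypothetical helper: best-effort lookup of the Nsight Compute CLI (`ncu`).

    Checks PATH first, then the default Linux layout
    <cuda_root>/cuda-<version>/nsight-compute-<version>/ncu.
    Returns the path as a string, or None if nothing executable is found.
    """
    on_path = shutil.which("ncu")
    if on_path:
        return on_path
    pattern = os.path.join(cuda_root, "cuda-*", "nsight-compute-*", "ncu")
    # sorted() is lexicographic, so "last" is a heuristic, not a true version sort.
    candidates = sorted(p for p in glob.glob(pattern) if os.access(p, os.X_OK))
    return candidates[-1] if candidates else None


if __name__ == "__main__":
    print(find_ncu() or "ncu not found; see the Nsight Compute installation docs")
```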
