
Boost Your Replicate Models with Pruna AI: A Step-by-Step Guide

Jan 21, 2025

John Rachwan

CTO

What is Replicate?

Deep learning models are getting bigger and more powerful, but making them run efficiently can be tricky. Replicate is a great platform for running machine learning models in production. But even with its awesome capabilities, unoptimized models can rack up costs, slow down inference, and waste resources.

That’s why we’re excited to introduce you to Pruna AI—a smart, easy-to-use package designed to make your models faster, leaner, and more efficient.

In this blog, we will walk you through how to use Pruna to optimize your models and deploy them on Replicate. Whether you’re a seasoned ML engineer or just starting out, this guide will help you boost the performance of your models. It is based on the following GitHub repository published by Pruna AI.

Why Optimize Your Models?

When it comes to deploying machine learning models, optimization is key to ensuring scalability and cost-effectiveness. Here’s why you should care:

  • Faster Inference Times: Optimized models run quicker, leading to better user experiences and, as a result, higher user retention.

  • Smaller Models: Optimized models are smaller, which allows you to either use a smaller GPU to run the same model or to run multiple models on the same GPU.

  • Lower Computational Costs: Faster inference times and smaller models allow you to save costs by paying less for the same number of inferences handled and by switching to smaller, cheaper GPUs, respectively.

  • Environmental Impact: Smaller and faster models consume less energy, making AI more sustainable.

Traditionally, model optimization required deep technical knowledge and significant effort. Pruna simplifies this by automating advanced techniques like quantization, pruning, caching, compilation, and more, tailored to your specific use case.

Getting Started with Pruna

Step 0: Folder Structure

This is the structure we will be following in this blog, sketched here from the key components listed below:
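.
├── flux-schnell/
│   ├── cog.yaml
│   └── predict.py
└── .github/
    └── workflows/
        └── push_flux_schnell.yaml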

Key Components

  1. flux-schnell

    1. cog.yaml: Configures the model for reproducible builds and runs on Replicate using Cog.

    2. predict.py: Implements the Predictor, handling loading, smashing, and inference.

  2. .github/workflows/

    1. push_flux_schnell.yaml: Automates model deployment to Replicate via GitHub Actions.

As you can see, each model has its own cog.yaml and predict.py scripts.

Step 1: Install Pruna

To use Pruna with Replicate, you’ll need Python ≥3.9 and any NVIDIA GPU available on Replicate. Replicate uses the Cog framework for containerizing and deploying models. To integrate Pruna, you’ll need to update your cog.yaml file. Here’s an example configuration:

build:
  gpu: true
  cuda: "12.1"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    - "git"
    - "build-essential"
  python_version: "3.11"
  run:
    - command: pip install pruna[gpu]==0.1.2 --extra-index-url https://prunaai.pythonanywhere.com/
    - command: pip install colorama
    - command: export CC=/usr/bin/gcc
predict: "predict.py:Predictor"

This setup ensures that Pruna is available during the build process and integrates seamlessly with your model code.

Step 2: Optimize Your Model

Pruna packages powerful optimization techniques behind a simple, easy-to-use interface. You can check out the Pruna documentation for a developer-friendly guide on how to use Pruna. In this blog, we will show you an example using the Flux Schnell model that is mostly based on this notebook.

1. Load Your Model

Start by loading your model using FluxPipeline. This will serve as the baseline model before optimization.

from diffusers import FluxPipeline
import torch

# Load the model
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    token="your_hugging_face_token"
).to("cuda")
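If you want to quantify the speed-up later, it helps to record a baseline before optimizing. Here is a minimal timing sketch; it assumes the pipe object loaded above, and the prompt and step count are purely illustrative:

import time
import torch

# Warm-up run so CUDA kernels are compiled and weights are resident before timing
pipe("a photo of an astronaut riding a horse", num_inference_steps=4)

torch.cuda.synchronize()
start = time.perf_counter()
pipe("a photo of an astronaut riding a horse", num_inference_steps=4)
torch.cuda.synchronize()
print(f"Baseline latency: {time.perf_counter() - start:.2f} s")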

2. Configure Pruna for Optimization

Define a SmashConfig object that specifies how the model should be optimized. Pruna allows you to customize parameters like caching and compilation.

from pruna import SmashConfig, smash

# Configure Pruna Smash
smash_config = SmashConfig()
smash_config['compilers'] = ['flux_caching']           # use Pruna's Flux caching compiler
smash_config['comp_flux_caching_cache_interval'] = 2   # reuse cached computations every 2 steps
smash_config['comp_flux_caching_start_step'] = 0       # start caching from the first step
smash_config['comp_flux_caching_compile'] = True       # also compile the cached pipeline
smash_config['comp_flux_caching_save_model'] = False   # don't persist the smashed model to disk

3. Compress your Model

Pass your model and configuration to Pruna’s smash function, which applies the compression. If you do not provide a Pruna token, you will automatically be asked to create one.

# Optimize the model
pipe = smash(
    model=pipe,
    token="your_pruna_token",  # Replace with your token
    smash_config=smash_config,
)

4. Use the Compressed Model

After compression, the model is ready for prediction.

# Generate output
image = pipe(
    prompt="Your prompt here",
    height=1024,
    width=1024,
    guidance_scale=7.5,
    num_inference_steps=4,
).images[0]
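To see what the optimization bought you, you can time the smashed pipeline the same way as in the baseline sketch above and compare the two numbers (again, the prompt is illustrative):

import time
import torch

image.save("output.png")  # keep the generated image around for inspection

# Warm up, then time the smashed pipeline for comparison with the baseline
pipe("Your prompt here", num_inference_steps=4)
torch.cuda.synchronize()
start = time.perf_counter()
pipe("Your prompt here", num_inference_steps=4)
torch.cuda.synchronize()
print(f"Smashed latency: {time.perf_counter() - start:.2f} s")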

Full Code Example

Below is the complete code for the predict.py script implementation, combining all the steps above:

import tempfile
import torch
from cog import BasePredictor, Input, Path
from diffusers import FluxPipeline
from pruna import SmashConfig, smash


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load and optimize the model"""
        # Load the model
        self.pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=torch.bfloat16,
            token="your_hugging_face_token"
        ).to("cuda")

        # Configure Pruna
        smash_config = SmashConfig()
        smash_config['compilers'] = ['flux_caching']
        smash_config['comp_flux_caching_cache_interval'] = 2
        smash_config['comp_flux_caching_start_step'] = 0
        smash_config['comp_flux_caching_compile'] = True
        smash_config['comp_flux_caching_save_model'] = False

        # Optimize the model
        self.pipe = smash(
            model=self.pipe,
            token="your_pruna_token",
            smash_config=smash_config,
        )

    def predict(
        self,
        prompt: str = Input(description="Prompt"),
        num_inference_steps: int = Input(
            description="Number of inference steps", default=4
        ),
        guidance_scale: float = Input(
            description="Guidance scale", default=7.5
        ),
        seed: int = Input(description="Seed", default=42),
        image_height: int = Input(description="Image height", default=1024),
        image_width: int = Input(description="Image width", default=1024),
        cache_interval: int = Input(description="Cache interval", default=3),
        start_step: int = Input(description="Start step", default=1),
    ) -> Path:
        """Run a prediction"""
        self.pipe.flux_cache_helper.set_params(
            cache_interval=cache_interval, start_step=start_step
        )

        image = self.pipe(
            prompt,
            height=image_height,
            width=image_width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator("cpu").manual_seed(seed)
        ).images[0]

        output_dir = Path(tempfile.mkdtemp())
        image_path = output_dir / "output.png"
        image.save(image_path)
        return image_path

Don’t forget to replace your_hugging_face_token and your_pruna_token with your Hugging Face and Pruna tokens, respectively.

That’s it! Your model is now faster, leaner, and ready for deployment.
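Before deploying, you can optionally sanity-check the container locally with Cog. This builds the image and runs a single prediction on your machine, so it needs a local NVIDIA GPU:

cog predict -i prompt="an astronaut riding a horse"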

Deploying the Optimized Model on Replicate

We are now ready to deploy our model to Replicate. This can be done either manually on the command line or through a GitHub workflow. We will explain both in this section:

Option 1: Setup GitHub Workflow

To streamline the deployment, you can use GitHub Actions to automate pushing models to Replicate. Here’s an example workflow file (push_flux_schnell.yaml):

name: Push Flux Schnell to Replicate

on:
  workflow_dispatch:
    inputs:
      model_name:
        default: "prunaai/flux-schnell"
jobs:
  push_to_replicate:
    name: Push to Replicate
    runs-on: ubuntu-latest
    steps:
      - name: Free disk space
        uses: jlumbroso/free-disk-space@v1.3.1
      - name: Checkout
        uses: actions/checkout@v4
      - name: Setup Cog
        uses: replicate/setup-cog@v2
        with:
          token: ${{ secrets.REPLICATE_API_TOKEN }}
      - name: Push to Replicate
        working-directory: flux-schnell
        run: |
          cog push r8.im/${{ inputs.model_name }}

This workflow allows for continuous deployment of your optimized model, saving time and effort. This is convenient if you want to update the model in terms of which SmashConfig to use, which input parameters to expect, or even to add new features like LoRA support.
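Because the workflow uses a workflow_dispatch trigger, you can start it from the Actions tab on GitHub or, if you use the GitHub CLI, from your terminal:

gh workflow run push_flux_schnell.yaml -f model_name=prunaai/flux-schnell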

Option 2: Pushing to Replicate Manually

If you need to push the model manually, navigate to the flux-schnell directory and run:

cog push r8.im/prunaai/flux-schnell

Attention: Ensure that your Replicate API token is set either in GitHub Secrets or in your local environment. You can follow this Replicate guide to do this.
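For a manual push, one common way to authenticate is to log in with Cog and paste your Replicate API token when prompted:

cog login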

🎉Congratulations, you deployed your optimized model to Replicate! 🎉

When you combine Pruna’s optimization superpowers with Replicate’s seamless deployment platform, you get the best of both worlds—performance, cost savings, and scalability all in one!

In just a few steps, you can make your models run faster, use fewer resources, and deliver even better results. Curious to see it in action? Check out the models we’ve deployed using Pruna and Replicate on our Pruna Replicate repository.

Excited to give it a shot? Download Pruna, integrate it into your workflow, and deploy your optimized models on Replicate. If you have any questions, our team is here to support you! Join the conversation and ask us anything on Discord.
