Skip to main content
The optimize() function from fal.toolkit takes a PyTorch module and applies dynamic compilation and quantization techniques to make it run faster on fal’s GPU infrastructure. It selects an optimization strategy suited to your model and hardware, so you get faster inference without manually configuring torch.compile, quantization, or kernel tuning. It is most useful for diffusion models and other large PyTorch models where inference latency matters. Call optimize() once in your setup() method after loading the model and assign its return value back (for example, model = optimize(model)); all subsequent inference calls then use the optimized version.
This API is currently experimental and may be subject to change.
from fal.toolkit import optimize

# optimize() takes a PyTorch module and returns the version to use from now on,
# so reassign the result. Call it once (e.g. in setup()) after loading the model.
model = optimize(model)

Example

import fal
from fal.toolkit import Image, optimize
from pydantic import BaseModel, Field


class Input(BaseModel):
    # Request payload for the text-to-image endpoint. Documented with comments
    # rather than a docstring so the generated JSON schema is not altered.

    # Text prompt passed straight to the diffusion pipeline. The example below
    # is shown to users in the schema, so keep its spelling correct.
    prompt: str = Field(
        description="The prompt to generate an image from.",
        examples=[
            "A cinematic shot of a baby raccoon wearing an intricate Italian priest robe.",
        ],
    )


class Output(BaseModel):
    # Response payload: the single image produced for a request. Kept free of a
    # class docstring so the generated JSON schema stays unchanged.
    image: Image = Field(description="The generated image.")


class FalModel(fal.App):
    # Example fal app: serves SDXL text-to-image on "/" and runs the pipeline's
    # UNet and VAE through fal.toolkit.optimize() during setup.
    machine_type = "GPU"
    requirements = [
        "accelerate",
        "transformers>=4.30.2",
        "diffusers>=0.26",
        "torch>=2.2.0",
    ]

    def setup(self) -> None:
        """Load SDXL onto the GPU, optimize its UNet and VAE, then warm up.

        The optimized pipeline is stored on ``self.pipeline`` for use by the
        endpoint below.
        """
        import torch
        from diffusers import AutoPipelineForText2Image

        # SDXL base weights: fp16 variant, loaded in half precision.
        self.pipeline = pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        pipe.to("cuda")

        # optimize() returns a replacement module, so each optimized component
        # must be assigned back onto the pipeline (UNet first, then VAE).
        for component in ("unet", "vae"):
            setattr(pipe, component, optimize(getattr(pipe, component)))

        # One throwaway generation so first-call work happens here rather than
        # during a user request.
        pipe(prompt="a cat", num_inference_steps=30)

    # Generate one image for the request prompt with 30 inference steps and
    # return it wrapped as a fal Image. Comments, not a docstring, so the
    # endpoint's API description is unchanged.
    @fal.endpoint("/")
    def text_to_image(self, input: Input) -> Output:
        images = self.pipeline(
            prompt=input.prompt,
            num_inference_steps=30,
        ).images
        # Exactly one image is expected; the unpacking raises otherwise.
        (image,) = images
        return Output(image=Image.from_pil(image))