Introduction
The landscape of artificial intelligence and machine learning continues to evolve rapidly, with significant advancements in generative AI models. One such notable development comes from Black Forest Labs with their FLUX.1 suite of models. These models push the boundaries of text-to-image synthesis, offering unparalleled image detail, prompt adherence, and style diversity. In this blog, we will delve into the process of fine-tuning the FLUX model using Dreambooth, a method that has gained traction for its effectiveness in producing high-quality, customized AI-generated content.
Understanding the FLUX.1 Model Family
Black Forest Labs has introduced three variants of the FLUX.1 model:
- FLUX.1 [pro]: The premium offering with top-tier image generation capabilities, available via API.
- FLUX.1 [dev]: An open-weight, guidance-distilled model for non-commercial use, providing efficient performance.
- FLUX.1 [schnell]: Designed for local development and personal use, available under an Apache 2.0 license.
Learn more in the official announcement from Black Forest Labs.
These models are based on a hybrid architecture of multimodal and parallel diffusion transformer blocks, scaled to 12 billion parameters. They offer state-of-the-art performance, surpassing other leading models.
What is Dreambooth?
Dreambooth is a technique for fine-tuning generative models with a small dataset to produce highly customized outputs. It leverages the existing capabilities of pre-trained models and enhances them with specific details, styles, or subjects provided in the fine-tuning dataset. This method is particularly useful for applications requiring personalized content generation.
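To make the idea concrete, here is a minimal sketch (with hypothetical paths and prompt strings) of what a Dreambooth run actually consumes: a handful of instance images bound to a prompt containing a rare identifier token, optionally paired with a generic class prompt for prior preservation.
from pathlib import Path

# A Dreambooth dataset in a nutshell: a few subject photos plus a prompt with a rare token.
instance_images = sorted(Path("./train-data/monu").glob("*.jpg"))  # typically 5-20 photos of the subject
instance_prompt = "photo of sks dog"   # 'sks' is the rare identifier the model learns to bind to the subject
class_prompt = "photo of a dog"        # generic prompt, used if prior preservation is enabled

print(f"{len(instance_images)} instance images paired with prompt: {instance_prompt!r}")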
Prerequisites
Before we proceed with fine-tuning the FLUX.1 [schnell] model using Dreambooth, ensure you have the following:
- Access to the FLUX.1 [schnell] model, which can be found on HuggingFace.
- A dataset containing images and corresponding text descriptions for fine-tuning.
- A computing environment with adequate resources (e.g., GPUs) to handle the training process.
Steps to Finetune Flux Using Dreambooth
In this blog, we’ll utilize Azure Machine Learning to fine-tune a text-to-image model to generate pictures of dogs based on textual input.
Before we begin, ensure you have the following:
- An Azure account with access to Azure Machine Learning.
- A basic understanding of Python and Jupyter notebooks.
- Familiarity with Hugging Face’s Diffusers library.
Step 1: Set Up the Environment
First, set up your environment by installing the necessary libraries. You can use the following commands:
pip install transformers diffusers accelerate
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
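Optionally, you can run a quick sanity check (assuming a CUDA-capable machine) to confirm that PyTorch can see your GPU before continuing:
import torch

# Verify the installation and GPU visibility
print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))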
Step 2: Load the libraries
import sys
sys.path.insert(0, '..')
import os
import shutil
import random
from azure.ai.ml import automl, Input, Output, MLClient, command, load_job
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.ai.ml.entities import Data, Environment, AmlCompute
from azure.ai.ml.constants import AssetTypes
from azure.core.exceptions import ResourceNotFoundError
import matplotlib.pyplot as plt
import mlflow
from mlflow.tracking.client import MlflowClient
Before we dive into the code, you'll need to connect to your workspace. The workspace is the top-level resource for Azure Machine Learning, providing a centralized place to work with all the artifacts you create when you use Azure Machine Learning.
We use DefaultAzureCredential to get access to the workspace. DefaultAzureCredential should be capable of handling most scenarios. If you want to learn more about other available credentials, see the set up authentication doc and the azure-identity reference doc.
Replace AML_WORKSPACE_NAME, RESOURCE_GROUP and SUBSCRIPTION_ID with their respective values in the cell below.
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential
credential = DefaultAzureCredential()
workspace_ml_client = None
try:
    workspace_ml_client = MLClient.from_config(credential)
    subscription_id = workspace_ml_client.subscription_id
    resource_group = workspace_ml_client.resource_group_name
    workspace_name = workspace_ml_client.workspace_name
except Exception as ex:
    print(ex)
    # Enter details of your AML workspace
    subscription_id = "SUBSCRIPTION_ID"
    resource_group = "RESOURCE_GROUP"
    workspace_name = "AML_WORKSPACE_NAME"
    workspace_ml_client = MLClient(
        credential, subscription_id, resource_group, workspace_name
    )
registry_ml_client = MLClient(
    credential,
    subscription_id,
    resource_group,
    registry_name="azureml",
)
workspace = workspace_ml_client.workspace_name
subscription_id = workspace_ml_client.workspaces.get(workspace).id.split("/")[2]
resource_group = workspace_ml_client.workspaces.get(workspace).resource_group
local_train_data = './train-data/monu/' # Azure ML dataset will be created for training on this content
generated_images = './results/monu'
azureml_dataset_name = 'monu' # Name of the dataset
train_target = 'gpu-cluster-big'
experiment_name = 'dreambooth-finetuning'
training_env_name = 'dreambooth-flux-train-envn'
inference_env_name = 'flux-inference-envn'
Step 3: Prepare the Dataset
Prepare your dataset by organizing images and their descriptions. Ensure that the data is in a format compatible with Dreambooth. Here’s an example structure:
train-data/monu/
image_1.jpg
image_2.jpg
...
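Optionally, you can verify that every image in the folder opens cleanly and is in RGB before uploading; the sketch below assumes JPEG/PNG files under the local_train_data folder defined earlier.
from pathlib import Path
from PIL import Image

# Sanity-check the training images before registering them as a data asset
for path in sorted(Path(local_train_data).glob("*")):
    if path.suffix.lower() not in {".jpg", ".jpeg", ".png"}:
        continue
    with Image.open(path) as img:
        img = img.convert("RGB")
        print(path.name, img.size)  # training resizes/crops to --resolution later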
3.1 Upload the images to Datastore through an AML Data asset (URI Folder)
In order to use the data for training in Azure ML, we upload it to the default Azure Blob Storage of our Azure ML workspace.
# Register dataset
my_data = Data(
path= local_train_data,
type= AssetTypes.URI_FOLDER,
description= "Training images for Dreambooth finetuning",
name= azureml_dataset_name
)
workspace_ml_client.data.create_or_update(my_data)
Step 4: Create the Training Environment
We will require a dreambooth-conda.yaml file to create our custom environment.
name: dreambooth-flux-env
channels:
  - conda-forge
dependencies:
  - python=3.10
  - pip:
      - 'git+https://github.com/huggingface/diffusers.git'
      - transformers>=4.41.2
      - azureml-acft-accelerator==0.0.59
      - azureml_acft_common_components==0.0.59
      - azureml-acft-contrib-hf-nlp==0.0.59
      - azureml-evaluate-mlflow==0.0.59
      - azureml-metrics[text]==0.0.59
      - mltable==1.6.1
      - mpi4py==3.1.5
      - sentencepiece==0.1.99
      - transformers==4.44.0
      - datasets==2.17.1
      - optimum==1.17.1
      - accelerate>=0.31.0
      - onnxruntime==1.17.3
      - rouge-score==0.1.2
      - sacrebleu==2.4.0
      - bitsandbytes==0.43.3
      - einops==0.7.0
      - aiohttp==3.10.5
      - peft==0.8.2
      - deepspeed==0.15.0
      - trl==0.8.1
      - tiktoken==0.6.0
      - scipy==1.14.0
environment = Environment(
image="mcr.microsoft.com/azureml/curated/acft-hf-nlp-gpu:67",
conda_file="environment/dreambooth-conda.yaml",
name=training_env_name,
description="Dreambooth training environment",
)
workspace_ml_client.environments.create_or_update(environment)
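The environment image is built server-side the first time a job uses it. If you want to confirm the asset registered correctly, you can fetch it back; the snippet below assumes at least one version exists and uses the "latest" label.
# Confirm the environment asset was registered
env = workspace_ml_client.environments.get(name=training_env_name, label="latest")
print(env.name, env.version)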
Step 5: Create the Compute
In order to finetune a model on Azure Machine Learning studio, you will need to create a compute resource first. Creating a compute will take 3–4 minutes.
For additional references, see Azure Machine Learning in a Day.
try:
    _ = workspace_ml_client.compute.get(train_target)
    print("Found existing compute target.")
except ResourceNotFoundError:
    print("Creating a new compute target...")
    compute_config = AmlCompute(
        name=train_target,
        type="amlcompute",
        size="Standard_NC24ads_A100_v4",  # 1 x A100, 80 GB GPU memory each
        tier="low_priority",
        idle_time_before_scale_down=600,
        min_instances=0,
        max_instances=2,
    )
    workspace_ml_client.begin_create_or_update(compute_config)
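begin_create_or_update returns a long-running-operation poller, so the cell above does not wait for provisioning; if you want the notebook to block, you can call .result() on the poller. Either way, you can fetch the compute afterwards to confirm it exists:
# Optional: confirm the cluster exists and inspect its size
cluster = workspace_ml_client.compute.get(train_target)
print(f"Compute '{cluster.name}' with size {cluster.size} is available")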
Step 6: Submit the Fine-tuning Job
We will use the black-forest-labs/FLUX.1-schnell model in this notebook. In this step we assemble the Dreambooth LoRA training command and submit it as an Azure ML command job; the resulting fine-tuned weights will let the model generate images of our subject (here, a dog) from textual descriptions.
First, let's create the command-line instruction.
command_str = '''python prepare.py && accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-schnell" \
--instance_data_dir=${{inputs.input_data}} \
--output_dir="outputs/models" \
--mixed_precision="bf16" \
--instance_prompt="photo of sks dog" \
--class_prompt="photo of a dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=2500 \
--seed="0"'''
As you can see, running the above command requires two files: train_dreambooth_lora_flux.py and prepare.py. You can download train_dreambooth_lora_flux.py from the official Diffusers repository on GitHub.
Here is the code for prepare.py
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:100"
from accelerate.utils import write_basic_config
write_basic_config()
Your folder structure should be like this
src/
prepare.py
train_dreambooth_lora_flux.py
Now let's initialize some variables.
# Retrieve latest version of dataset
latest_version = [dataset.latest_version for dataset in workspace_ml_client.data.list() if dataset.name == azureml_dataset_name][0]
dataset_asset = workspace_ml_client.data.get(name= azureml_dataset_name, version= latest_version)
print(f'Latest version of {azureml_dataset_name}: {latest_version}')
inputs = {"input_data": Input(type=AssetTypes.URI_FOLDER, path=f'azureml:{azureml_dataset_name}:{latest_version}')}
outputs = {"output_dir": Output(type=AssetTypes.URI_FOLDER)}
Now we submit a command job using the code, compute, and environment created above.
job = command(
inputs = inputs,
outputs = outputs,
code = "./src",
command = command_str,
environment = f"{training_env_name}:latest",
compute = train_target,
experiment_name = experiment_name,
display_name= "flux-finetune-batchsize-1",
environment_variables = {'HF_TOKEN': 'Place Your HF Token Here'}
)
returned_job = workspace_ml_client.jobs.create_or_update(job)
returned_job
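The job runs asynchronously. To follow training progress from the notebook, you can stream its logs; this call blocks until the job finishes:
# Optional: stream stdout/driver logs of the submitted job until completion
workspace_ml_client.jobs.stream(returned_job.name)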
Step 7: Download the Finetuned Model and Register the Model
After fine-tuning, evaluate the model to ensure it meets your requirements.
We will register the model from the output of the fine-tuning job. This tracks lineage between the fine-tuned model and the fine-tuning job; the fine-tuning job, in turn, tracks lineage to the foundation model, data and training code.
# Obtain the tracking URL from MLClient
MLFLOW_TRACKING_URI = workspace_ml_client.workspaces.get(name=workspace_ml_client.workspace_name).mlflow_tracking_uri
# Set the MLFLOW TRACKING URI
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
# Initialize MLFlow client
mlflow_client = MlflowClient()
mlflow_run = mlflow_client.get_run(returned_job.name)
mlflow.artifacts.download_artifacts(run_id= mlflow_run.info.run_id,
artifact_path="outputs/models/", # Azure ML job output
dst_path="./train-artifacts") # local folder
Now let's download the LoRA weights file itself.
weights_path = "./train-artifacts/outputs/models/pytorch_lora_weights.safetensors"
# If the earlier artifact download left a directory at this path, remove it so the
# file download below does not fail.
if os.path.isdir(weights_path):
    shutil.rmtree(weights_path)
mlflow.artifacts.download_artifacts(run_id=mlflow_run.info.run_id,
                                    artifact_path="outputs/models/pytorch_lora_weights.safetensors",  # Azure ML job output
                                    dst_path="./train-artifacts")  # local folder
Finally, let's register the model.
from azure.ai.ml.entities import Model
from azure.ai.ml.constants import AssetTypes
run_model = Model(
path=f"azureml://jobs/{returned_job.name}/outputs/artifacts/paths/outputs/models/pytorch_lora_weights.safetensors",
name="mano-dreambooth-flux-finetuned",
description="Model created from run.",
type=AssetTypes.CUSTOM_MODEL,
)
model = workspace_ml_client.models.create_or_update(run_model)
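The returned object carries the version Azure ML assigned to the model, which is handy if you later want to reference the registered asset directly instead of the job output path:
# Inspect the registered model; it can also be fetched again later by name
print(f"Registered model: {model.name}, version: {model.version}")
registered_model = workspace_ml_client.models.get(name="mano-dreambooth-flux-finetuned", label="latest")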
Step 8: Online Managed Endpoint Deployment
Now let's deploy the fine-tuned model as an Online Managed Endpoint on AML. First, let's define some constants to be used later while deploying.
endpoint_name = 'flux-endpoint-finetuned-a100'
deployment_name = 'flux'
instance_type = 'Standard_NC24ads_A100_v4'
score_file = 'score.py'
Let's create a managed online endpoint.
from azure.ai.ml.entities import ManagedOnlineEndpoint

# create an online endpoint
endpoint = ManagedOnlineEndpoint(
name=endpoint_name,
description="this is the flux inference online endpoint",
auth_mode="key"
)
workspace_ml_client.online_endpoints.begin_create_or_update(endpoint)
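Endpoint creation is also a long-running operation; once provisioning completes, you can fetch the endpoint to check its state and scoring URI:
# Optional: check the endpoint once provisioning finishes
endpoint = workspace_ml_client.online_endpoints.get(name=endpoint_name)
print("Provisioning state:", endpoint.provisioning_state)
print("Scoring URI:", endpoint.scoring_uri)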
Step 9: Create Inference Environment for Online Endpoint
First we create a Dockerfile which will be used during creation of our environment.
FROM mcr.microsoft.com/aifx/acpt/stable-ubuntu2004-cu121-py310-torch22x:biweekly.202408.3
# Install pip dependencies
COPY requirements.txt .
RUN pip install -r requirements.txt --no-cache-dir
# Inference requirements
COPY --from=mcr.microsoft.com/azureml/o16n-base/python-assets:20230419.v1 /artifacts /var/
RUN /var/requirements/install_system_requirements.sh && \
cp /var/configuration/rsyslog.conf /etc/rsyslog.conf && \
cp /var/configuration/nginx.conf /etc/nginx/sites-available/app && \
ln -sf /etc/nginx/sites-available/app /etc/nginx/sites-enabled/app && \
rm -f /etc/nginx/sites-enabled/default
ENV SVDIR=/var/runit
ENV WORKER_TIMEOUT=400
EXPOSE 5001 8883 8888
# support Deepspeed launcher requirement of passwordless ssh login
RUN apt-get update
RUN apt-get install -y openssh-server openssh-client
The requirements.txt referenced by this Dockerfile is as follows.
azureml-core==1.57.0
azureml-dataset-runtime==1.57.0
azureml-defaults==1.57.0
azure-ml==0.0.1
azure-ml-component==0.9.18.post2
azureml-mlflow==1.57.0
azureml-contrib-services==1.57.0
torch-tb-profiler~=0.4.0
azureml-inference-server-http
inference-schema
MarkupSafe==2.1.2
regex
pybind11
urllib3>=1.26.18
cryptography>=42.0.4
aiohttp>=3.8.5
py-spy==0.3.12
debugpy~=1.6.3
ipykernel~=6.0
tensorboard
psutil~=5.8.0
matplotlib~=3.5.0
tqdm~=4.66.3
py-cpuinfo==5.0.0
transformers==4.44.2
diffusers==0.30.1
accelerate>=0.31.0
sentencepiece
peft
bitsandbytes
Make sure the folder structure is in the following format (matching the build context path used in the next cell):
docker-contexts/python-and-pip/
Dockerfile
requirements.txt
Finally, let's run the code below to create our inference environment for the FLUX LoRA model.
from azure.ai.ml.entities import BuildContext

env_docker_context = Environment(
    build=BuildContext(path="docker-contexts/python-and-pip"),
    name=inference_env_name,
    description="Environment created from a Docker context.",
)
workspace_ml_client.environments.create_or_update(env_docker_context)
Step 10: Create Deployment for the Managed Online Endpoint
Finally, let's deploy the model to the endpoint we created. Let's create a file called score.py and put it under a folder called assets.
assets/
score.py
import torch
import io
import os
import logging
import json
import math
import numpy as np
from base64 import b64encode
import requests
from PIL import Image, ImageDraw
from safetensors.torch import load_file
from azureml.contrib.services.aml_response import AMLResponse
from transformers import pipeline
from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline
from diffusers import AutoPipelineForText2Image, FluxPipeline
from diffusers.schedulers import EulerAncestralDiscreteScheduler
from diffusers import DPMSolverMultistepScheduler
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def init():
    """
    This function is called when the container is initialized/started, typically after create/update of the deployment.
    You can write the logic here to perform init operations like caching the model in memory
    """
    global pipe, refiner
    weights_path = os.path.join(
        os.getenv("AZUREML_MODEL_DIR"), "pytorch_lora_weights.safetensors"
    )
    print("weights_path:", weights_path)
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()
    pipe.load_lora_weights(weights_path, use_safetensors=True)
    pipe.to(device)
    # refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    #     "stabilityai/stable-diffusion-xl-refiner-1.0",
    #     torch_dtype=torch.float16,
    #     use_safetensors=True,
    #     variant="fp16"
    # )
    # refiner.to(device)
    logging.info("Init complete")
def get_image_object(image_url):
    """
    This function takes an image URL and returns an Image object.
    """
    response = requests.get(image_url)
    init_image = Image.open(io.BytesIO(response.content)).convert("RGB")
    return init_image
def prepare_response(images):
    """
    This function takes a list of images and converts them to a dictionary of base64 encoded strings.
    """
    ENCODING = 'utf-8'
    dic_response = {}
    for i, image in enumerate(images):
        output = io.BytesIO()
        image.save(output, format="JPEG")
        base64_bytes = b64encode(output.getvalue())
        base64_string = base64_bytes.decode(ENCODING)
        dic_response[f'image_{i}'] = base64_string
    return dic_response
def design(prompt, image=None, num_images_per_prompt=4, negative_prompt=None, strength=0.65, guidance_scale=7.5, num_inference_steps=50, seed=None, design_type='TXT_TO_IMG', mask=None, other_args=None):
    """
    This function takes various parameters like prompt, image, seed, design_type, etc., and generates images based on the specified design type. It returns a list of generated images.
    """
    generator = None
    if seed:
        generator = torch.manual_seed(seed)
    else:
        generator = torch.manual_seed(0)
    print('other_args', other_args)
    image = pipe(prompt=prompt,
                 height=512,
                 width=768,
                 guidance_scale=guidance_scale,
                 output_type="pil",
                 generator=generator).images[0]
    # image = refiner(prompt=prompt, image=image[None, :], generator=generator).images[0]
    return [image]
def run(raw_data):
    """
    This function takes raw data as input, processes it, and calls the design function to generate images.
    It then prepares the response and returns it.
    """
    logging.info("Request received")
    print(f'raw data: {raw_data}')
    data = json.loads(raw_data)["data"]
    print(f'data: {data}')
    prompt = data['prompt']
    negative_prompt = data['negative_prompt']
    seed = data['seed']
    num_images_per_prompt = data['num_images_per_prompt']
    guidance_scale = data['guidance_scale']
    num_inference_steps = data['num_inference_steps']
    design_type = data['design_type']
    image_url = None
    mask_url = None
    mask = None
    other_args = None
    image = None
    strength = data.get('strength', 0.65)  # optional; falls back to the default strength
    if 'mask_image' in data:
        mask_url = data['mask_image']
        mask = get_image_object(mask_url)
    if 'other_args' in data:
        other_args = data['other_args']
    if 'image_url' in data:
        image_url = data['image_url']
        image = get_image_object(image_url)
    with torch.inference_mode():
        images = design(prompt=prompt, image=image,
                        num_images_per_prompt=num_images_per_prompt,
                        negative_prompt=negative_prompt, strength=strength,
                        guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
                        seed=seed, design_type=design_type, mask=mask, other_args=other_args)
    preped_response = prepare_response(images)
    resp = AMLResponse(message=preped_response, status_code=200, json_str=True)
    return resp
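Before deploying, you can optionally smoke-test score.py locally. This sketch assumes a local GPU with enough memory, the inference requirements installed, and that you run it from inside the assets folder; the relative path to the downloaded LoRA weights is an assumption you may need to adjust.
# Hypothetical local smoke test for score.py
import os
os.environ["AZUREML_MODEL_DIR"] = os.path.abspath("../train-artifacts/outputs/models")
import score

score.init()  # loads FLUX.1-schnell plus the fine-tuned LoRA weights
images = score.design(prompt="photo of sks dog on a beach", seed=42)
images[0].save("local_test.jpg")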
Finally, we can go ahead and create the deployment.
from azure.ai.ml.entities import (
    ManagedOnlineDeployment,
    CodeConfiguration,
    OnlineRequestSettings,
    ProbeSettings,
)

deployment = ManagedOnlineDeployment(
    name=deployment_name,
    endpoint_name=endpoint_name,
    model=model,
    environment=env_docker_context,
    code_configuration=CodeConfiguration(
        code="assets", scoring_script=score_file
    ),
    instance_type=instance_type,
    instance_count=1,
    request_settings=OnlineRequestSettings(request_timeout_ms=90000, max_queue_wait_ms=900000, max_concurrent_requests_per_instance=5),
    liveness_probe=ProbeSettings(
        failure_threshold=30,
        success_threshold=1,
        timeout=2,
        period=10,
        initial_delay=1000,
    ),
    readiness_probe=ProbeSettings(
        failure_threshold=10,
        success_threshold=1,
        timeout=10,
        period=10,
        initial_delay=1000,
    ),
    environment_variables={'HF_TOKEN': 'Place Your HF Token Here'},
)
workspace_ml_client.online_deployments.begin_create_or_update(deployment).result()
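Once the deployment succeeds, you can optionally route 100% of the endpoint's traffic to it (so the endpoint can be invoked without naming a deployment) and retrieve the authentication keys; this is a small sketch using the standard SDK calls:
# Route all traffic to the new deployment and fetch the endpoint auth keys
endpoint = workspace_ml_client.online_endpoints.get(name=endpoint_name)
endpoint.traffic = {deployment_name: 100}
workspace_ml_client.online_endpoints.begin_create_or_update(endpoint).result()

keys = workspace_ml_client.online_endpoints.get_keys(name=endpoint_name)
print("Scoring URI:", endpoint.scoring_uri)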
Step 11: Test the Deployment
Finally, we are ready to test the endpoint.
# Create a request json matching the schema score.py expects
import json

request_json = {
    "data": {
        "prompt": "a photo of sks dog in a bucket",
        "negative_prompt": "blurry; three legs",
        "seed": 42,
        "num_images_per_prompt": 1,
        "guidance_scale": 7.5,
        "num_inference_steps": 50,
        "design_type": "TXT_TO_IMG",
        "strength": 0.65,
    }
}
request_file_name = "sample_request_data.json"
with open(request_file_name, "w") as request_file:
    json.dump(request_json, request_file)
responses = workspace_ml_client.online_endpoints.invoke(
    endpoint_name=endpoint_name,
    deployment_name=deployment_name,
    request_file=request_file_name,
)
responses = json.loads(responses)
import base64
from io import BytesIO
from PIL import Image
# score.py returns a dictionary of base64-encoded images keyed image_0, image_1, ...
for key, base64_string in responses.items():
    image_stream = BytesIO(base64.b64decode(base64_string))
    image = Image.open(image_stream)
    display(image)
Conclusion
Fine-tuning the FLUX model using Dreambooth is a powerful way to customize generative AI models for specific applications. By following the steps outlined in this blog, you can leverage the strengths of the FLUX.1 [schnell] model and enhance it with your unique dataset, achieving high-quality, personalized outputs. Whether you're working on creative projects, research, or commercial applications, this approach offers a robust solution for advancing your AI capabilities.