r/StableDiffusion

The Gory Details of Finetuning SDXL for 30M samples

Tutorial - Guide

There are plenty of details out there on how to train SDXL LoRAs, but details on how the big SDXL finetunes were trained are scarce, to say the least. I recently released a big SDXL finetune: 1.5M images, 30M training samples, 5 days on an 8xH100. So I'm sharing all the training details here to help the community.

Finetuning SDXL

bigASP was trained on about 1,440,000 photos, all with resolutions larger than their respective aspect ratio bucket. Each image is about 1MB on disk, making the dataset about 1TB per million images.

Every image goes through: a quality model to rate it from 0 to 9; JoyTag to tag it; and OWLv2 with the prompt "a watermark" to detect watermarks in the images. I found OWLv2 to perform better than even a finetuned vision model, and it has the added benefit of providing bounding boxes for the watermarks. Accuracy is about 92%. While it wasn't done for this version, it's possible in the future that the bounding boxes could be used to do "loss masking" during training, which basically hides the watermarks from SD. For now, if a watermark is detected, a "watermark" tag is included in the training prompt.
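
For illustration, the zero-shot watermark check looks roughly like this with the transformers library (the checkpoint name and threshold here are placeholders, not necessarily what bigASP used):

```python
import torch
from PIL import Image
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# Placeholder checkpoint; any OWLv2 detection checkpoint works the same way.
CKPT = "google/owlv2-base-patch16-ensemble"
processor = Owlv2Processor.from_pretrained(CKPT)
model = Owlv2ForObjectDetection.from_pretrained(CKPT)

@torch.no_grad()
def detect_watermarks(path, threshold=0.1):
    image = Image.open(path).convert("RGB")
    inputs = processor(text=[["a watermark"]], images=image, return_tensors="pt")
    outputs = model(**inputs)
    # Convert raw logits/boxes into detections above the threshold, in pixel coordinates.
    target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
    result = processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]
    # Any detection => add the "watermark" tag to the training prompt.
    return list(zip(result["scores"].tolist(), result["boxes"].tolist()))
```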

Images with a score of 0 are dropped entirely. I did a lot of work specifically training the scoring model to put certain images down in this score bracket. You'd be surprised at how much junk comes through in datasets, and even a hint of them can really throw off training. Thumbnails, video preview images, ads, etc.

bigASP uses the same aspect ratio buckets that SDXL's paper defines. Each image is assigned to the bucket it best fits while not being smaller than any dimension of that bucket when scaled down. After scaling, images get randomly cropped. The original resolution and crop data are recorded alongside the VAE-encoded image on disk for conditioning SDXL, and finally the latent is gzipped. I found gzip to provide a nice 30% space savings. This reduces the training dataset down to about 100GB per million images.
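
A minimal sketch of the precompute-and-gzip step (the record layout and field names here are just illustrative):

```python
import gzip, io
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
).to("cuda").eval()

@torch.no_grad()
def encode_and_save(pixels, original_size, crop_top_left, out_path):
    # pixels: (1, 3, H, W) in [-1, 1], already bucket-resized and randomly cropped
    latent = vae.encode(pixels.to("cuda")).latent_dist.sample()
    record = {
        "latent": latent.cpu(),          # what the UNet actually trains on
        "original_size": original_size,  # SDXL size conditioning
        "crop_top_left": crop_top_left,  # SDXL crop conditioning
    }
    buf = io.BytesIO()
    torch.save(record, buf)
    with gzip.open(out_path, "wb") as f:  # ~30% smaller on disk
        f.write(buf.getvalue())
```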

Training was done using a custom training script based off the diffusers library. I used a custom training script so that I could fully understand all the inner mechanics and implement any tweaks I wanted. Plus I had my training scripts from SD1.5 training, so it wasn't a huge leap. The downside is that a lot of time had to be spent debugging subtle issues that cropped up after several bugged runs. Those are all expensive mistakes. But, for me, mistakes are the cost of learning.

I think the training prompts are really important to the performance of the final model in actual usage. The custom Dataset class is responsible for doing a lot of heavy lifting when it comes to generating the training prompts. People prompt with everything from short prompts to long prompts, to prompts with all kinds of commas, underscores, typos, etc.

I pulled a large sample of AI images that included prompts to analyze the statistics of typical user prompts. The distribution of prompt length followed a mostly normal distribution, with a mean of 32 tags and a std of 19.8. So my Dataset class reflects this. For every training sample, it picks a random integer in this distribution to determine how many tags it should use for this training sample. It shuffles the tags on the image and then truncates them to that number.
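
In sketch form (the constants come from the fitted distribution above; the function names are illustrative):

```python
import random

MEAN_TAGS, STD_TAGS = 32, 19.8  # fitted to the sampled user prompts

def sample_tag_count():
    # Draw from the (roughly) normal distribution of user prompt lengths,
    # clamped so every training prompt keeps at least one tag.
    return max(1, round(random.gauss(MEAN_TAGS, STD_TAGS)))

def truncate_tags(tags):
    tags = list(tags)
    random.shuffle(tags)
    return tags[:sample_tag_count()]
```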

This means that during training the model sees everything from just "1girl" to a huge 224 token prompt. And thus, hopefully, learns to fill in the details for the user.

Certain tags, like watermark, are given priority and always included if present, so the model learns those tags strongly. This also has the side effect of conditioning the model to not generate watermarks unless asked during inference.

The tag alias list from danbooru is used to randomly mutate tags to synonyms so that bigASP understands all the different ways people might refer to a concept. Hopefully.

And, of course, the score tags. Just like Pony XL, bigASP encodes the score of a training sample as a range of tags of the form "score_X" and "score_X_up". However, to avoid the issues Pony XL ran into (shoulders of giants), only a random number of score tags are included in the training prompt. It includes between 1 and 3 randomly selected score tags that are applicable to the image. That way the model doesn't require "score_8, score_7, score_6, score_5..." in the prompt to work correctly. It's already used to just a single, or a couple score tags being present.
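
A rough sketch of the score tag selection (the exact set of "applicable" tags is my assumption here; the point is that only 1-3 of them end up in the prompt):

```python
import random

def pick_score_tags(score):
    # Assumed applicable set for an image scored `score` (1-9):
    # its own bucket plus every "score_X_up" bucket at or below it.
    applicable = [f"score_{score}"] + [f"score_{i}_up" for i in range(1, score + 1)]
    k = random.randint(1, 3)
    return random.sample(applicable, k=min(k, len(applicable)))
```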

10% of the time the prompt is dropped completely, being set to an empty string. UCG, you know the deal. N.B.!!! I noticed in Stability's training scripts, and even HuggingFace's scripts, that instead of setting the prompt to an empty string, they set it to "zero" in the embedding space. This is different from how SD1.5 was trained. And it's different from how most of the SD front-ends do inference on SD. My theory is that it can actually be a big problem if SDXL is trained with "zero" dropping instead of empty prompt dropping. That means that during inference, if you use an empty prompt, you're telling the model to move away not from the "average image", but only from images that happened to have no caption during training. That doesn't sound right. So for bigASP I opt to train with empty prompt dropping.
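
The dropping itself is trivial; the point is that the empty string still goes through the tokenizer and text encoders like any other prompt (sketch, with the 10% rate from above):

```python
import random

UCG_RATE = 0.10

def maybe_drop_prompt(prompt):
    # The empty string is tokenized and encoded as usual, rather than
    # zeroing the text embedding after the fact.
    return "" if random.random() < UCG_RATE else prompt
```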

Additionally, Stability's training scripts include dropping of SDXL's other conditionings: original_size, crop, and target_size. I didn't see this behavior present in kohya's scripts, so I didn't use it. I'm not entirely sure what benefit it would provide.

I made sure that during training, the model gets a variety of batched prompt lengths. What I mean is, the prompts themselves for each training sample are certainly different lengths, but they all have to be padded to the longest example in a batch. So it's important to ensure that the model still sees a variety of lengths even after batching, otherwise it might overfit to a specific range of prompt lengths. A quick Python Notebook to scan the training batches helped to verify a good distribution: 25% of batches were 225 tokens, 66% were 150, and 9% were 75 tokens. Though in future runs I might try to balance this more.
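
A sketch of the kind of check that notebook does (the batch field name is an assumption):

```python
from collections import Counter

def padded_length_histogram(dataloader, n_batches=1000):
    # Assumes each batch carries its padded token ids as batch["input_ids"].
    counts = Counter()
    for i, batch in enumerate(dataloader):
        if i >= n_batches:
            break
        counts[batch["input_ids"].shape[1]] += 1  # padded length of this batch
    total = sum(counts.values())
    return {length: count / total for length, count in sorted(counts.items())}
```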

The rest of the training process is fairly standard. I found min-snr loss to work best in my experiments. Pure fp16 training did not work for me, so I had to resort to mixed precision with the model in fp32. Since the latents are already encoded, the VAE doesn't need to be loaded, saving precious memory. For generating sample images during training, I use a separate machine which grabs the saved checkpoints and generates the sample images. Again, that saves memory and compute on the training machine.
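
For reference, the min-snr weighting (for epsilon-prediction, with min_snr_gamma=5) boils down to the following; how it gets folded into the loss is shown in the usage comment:

```python
import torch

def min_snr_weights(timesteps, alphas_cumprod, gamma=5.0):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t); the min-SNR weighting for
    # epsilon-prediction is min(SNR, gamma) / SNR.
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    return torch.clamp(snr, max=gamma) / snr

# Usage: per-sample MSE is scaled by these weights before averaging:
#   loss = (min_snr_weights(t, acp) * mse.mean(dim=(1, 2, 3))).mean()
```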

The final run uses an effective batch size of 2048, no EMA, no offset noise, PyTorch's AMP with just float16 (not bfloat16), 1e-4 learning rate, AdamW, min-snr loss, 0.1 weight decay, cosine annealing with linear warmup for 100,000 training samples, 10% UCG rate, text encoder 1 training is enabled, text encoder 2 is kept frozen, min_snr_gamma=5, PyTorch GradScaler with an initial scaling of 65k, 0.9 beta1, 0.999 beta2, 1e-8 eps. Everything is initialized from SDXL 1.0.
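
As a sketch, those settings translate to something like the following (the scheduler helper from transformers is one choice for the cosine-with-warmup schedule, not necessarily what the actual script uses):

```python
import torch
from transformers import get_cosine_schedule_with_warmup

def build_optimizer(trainable_params, total_samples=30_000_000,
                    warmup_samples=100_000, effective_batch=2048):
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=0.1,
                                  betas=(0.9, 0.999), eps=1e-8)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_samples // effective_batch,
        num_training_steps=total_samples // effective_batch,
    )
    # fp16 autocast with the weights kept in fp32; 65536 is the "65k" initial scale
    scaler = torch.cuda.amp.GradScaler(init_scale=65536.0)
    return optimizer, scheduler, scaler
```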

A validation dataset of 2048 images is used. Validation is performed every 50,000 samples to ensure that the model is not overfitting and to help guide hyperparameter selection. To help compare runs with different loss functions, validation is always performed with the basic loss function, even if training is using e.g. min-snr. And a checkpoint is saved every 500,000 samples. I find that it's really only helpful to look at sample images every million samples, so that process is run on every other checkpoint.

A stable training loss is also logged (I use Wandb to monitor my runs). Stable training loss is calculated at the same time as validation loss (one after the other). It's basically like a validation pass, except instead of using the validation dataset, it uses the first 2048 images from the training dataset, and uses a fixed seed. This provides a, well, stable training loss. SD's training loss is incredibly noisy, so this metric provides a much better gauge of how training loss is progressing.

The batch size I use is quite large compared to the few values I've seen online for finetuning runs. But it's informed by my experience with training other models. Large batch size wins in the long run, but is worse in the short run, so its efficacy can be challenging to measure on small scale benchmarks. Hopefully it was a win here. Full runs on SDXL are far too expensive for much experimentation here. But one immediate benefit of a large batch size is that iteration speed is faster, since optimization and gradient sync happens less frequently.

Training was done on an 8xH100 sxm5 machine rented in the cloud. On this machine, iteration speed is about 70 images/s. That means the whole run took about 5 solid days of computing. A staggering number for a hobbyist like me. Please send hugs. I hurt.

Training being done in the cloud was a big motivator for the use of precomputed latents. Takes me about an hour to get the data over to the machine to begin training. Theoretically the code could be set up to start training immediately, as the training data is streamed in for the first pass. It takes even the 8xH100 four hours to work through a million images, so data can be streamed faster than it's training. That way the machine isn't sitting idle burning money.

One disadvantage of precomputed latents is, of course, the lack of regularization from varying the latents between epochs. The model still sees a very large variety of prompts between epochs, but it won't see different crops of images or variations in VAE sampling. In future runs what I might do is have my local GPUs re-encoding the latents constantly and streaming those updated latents to the cloud machine. That way the latents change every few epochs. I didn't detect any overfitting on this run, so it might not be a big deal either way.

Finally, the loss curve. I noticed a rather large variance in the validation loss between different datasets, so it'll be hard for others to compare, but for what it's worth:

https://i.imgur.com/74VQYLS.png

Learnings and the Future

I had a lot of failed runs before this release, as mentioned earlier. Mostly bugs in the training script, like having the height and width swapped for the original_size, etc conditionings. Little details like that are not well documented, unfortunately. And a few runs to calibrate hyperparameters: trying different loss functions, optimizers, etc. Animagine's hyperparameters were the most well documented that I could find, so they were my starting point. Shout out to that team!

I didn't find any overfitting on this run, despite it being over 20 epochs of the data. That said, 30M training samples, as large as it is to me, pales in comparison to Pony XL which, as far as I understand, did roughly the same number of epochs just with 6M! images. So at least 6x the amount of training I poured into bigASP. Based on my testing of bigASP so far, it has nailed down prompt following and understands most of the tags I've thrown at it. But the undertraining is apparent in its inconsistency with overall image structure and having difficulty with more niche tags that occur less than 10k times in the training data. I would definitely expect those things to improve with more training.

Initially for encoding the latents I did "mixed-VAE" encoding. Basically, I load in several different VAEs: SDXL at fp32, SDXL at fp16, SDXL at bf16, and the fp16-fix VAE. Then each image is encoded with a random VAE from this list. The idea is to help make the UNet robust to any VAE version the end user might be using.
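
Sketched out, the mixed-VAE encoding looks roughly like this (the fp16-fix checkpoint name here is the commonly used community one, shown for illustration):

```python
import random
import torch
from diffusers import AutoencoderKL

SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
vaes = [
    AutoencoderKL.from_pretrained(SDXL, subfolder="vae", torch_dtype=torch.float32),
    AutoencoderKL.from_pretrained(SDXL, subfolder="vae", torch_dtype=torch.float16),
    AutoencoderKL.from_pretrained(SDXL, subfolder="vae", torch_dtype=torch.bfloat16),
    AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16),
]

@torch.no_grad()
def encode_mixed(pixels):
    # Pick a random VAE per image so the UNet sees the spread of latent "flavors".
    vae = random.choice(vaes).to("cuda")
    latent = vae.encode(pixels.to("cuda", dtype=vae.dtype)).latent_dist.sample()
    return latent.float().cpu()
```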

During training I noticed the model generating a lot of weird, high resolution patterns. It's hard to say the root cause. Could be moire patterns in the training data, since the dataset's resolution is so high. But I did use Lanczos interpolation so that should have been minimized. It could be inaccuracies in the latents, so I swapped over to just SDXL fp32 part way through training. Hard to say if that helped at all, or if any of that mattered. At this point I suspect that SDXL's VAE just isn't good enough for this task, where the majority of training images contain extreme amounts of detail. bigASP is very good at generating detailed, up close skin texture, but high frequency patterns like sheer nylon cause, I assume, the VAE to go crazy. More investigation is needed here. Or, god forbid, more training...

Of course, descriptive captions would be a nice addition in the future. That's likely to be one of my next big upgrades for future versions. JoyTag does a great job at tagging the images, so my goal is to do a lot of manual captioning to train a new LLaVa style model where the image embeddings come from both CLIP and JoyTag. The combo should help provide the LLM with both the broad generic understanding of CLIP and the detailed, uncensored tag based knowledge of JoyTag. Fingers crossed.

Finally, I want to mention the quality/aesthetic scoring model I used. I trained my own from scratch by manually rating images in a head-to-head fashion. Then I trained a model that takes as input the CLIP-B embeddings of two images and predicts the winner, based on this manual rating data. From that I could run ELO on a larger dataset to build a ranked dataset, and finally train a model that takes a single CLIP-B embedding and outputs a logit prediction across the 10 ranks.
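
In sketch form, the pairwise ranker and the ELO pass look something like this (layer sizes and K-factor are illustrative):

```python
import torch
import torch.nn as nn

class PairwiseRanker(nn.Module):
    """Takes CLIP-B embeddings of two images and predicts which one wins."""
    def __init__(self, dim=512):
        super().__init__()
        self.head = nn.Sequential(nn.Linear(dim * 2, 256), nn.ReLU(), nn.Linear(256, 1))

    def forward(self, emb_a, emb_b):
        # Sigmoid of this logit = P(image A beats image B); train with BCEWithLogitsLoss.
        return self.head(torch.cat([emb_a, emb_b], dim=-1)).squeeze(-1)

def elo_update(rating_a, rating_b, a_won, k=32.0):
    # Standard ELO update, run over model-judged matches to rank the larger set.
    expected_a = 1.0 / (1.0 + 10 ** ((rating_b - rating_a) / 400.0))
    delta = k * ((1.0 if a_won else 0.0) - expected_a)
    return rating_a + delta, rating_b - delta
```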

This worked surprisingly well, given that I only rated a little over two thousand images. Definitely better for my task than the older aesthetic model that Stability uses. Blurry/etc images tended toward lower ranks, and higher quality photoshoot type photos tended towards the top.

That said, I think a lot more work could be done here. One big issue I want to avoid is having the quality model bias the Unet towards generating a specific "style" of image, like many of the big image gen models currently do. We all know that DALL-E look. So the goal of a good quality model is to ensure that it doesn't rank images based on a particular look/feel/style, but on a less biased metric of just "quality". Certainly a difficult and nebulous concept. To that end, I think my quality model could benefit from more rating data where images with very different content and styles are compared.

Conclusion

I hope all of these details help others who might go down this painful path.

u/fpgaminer avatar

Here's a dump of the custom training scripts: https://github.com/fpgaminer/bigasp-training

That's not intended to be something other people use, but just putting it out there for whatever it might be worth to others. It's the exact scripts used to train this version of bigASP, except the code for the scoring model, the watermark detector, and some misc stuff for data munging my local datasets. Frankly, the scoring model is a pile of Python Notebooks so... no one wants that.

u/vvorkingclass avatar

Thanks, I appreciate your willingness to share. It inspires others to do the same. It inspires me, too. Thank you.

I'd be interested in the notebook for scoring! No worries if it's super messy, happy to figure it out :). Thanks for your work

u/fpgaminer avatar

I added them to the repo. God help you.

But yeah, dumped them as-is into the quality-arena folder, along with the watermark-detector and grid-detector.

Thank you unknown hero!


Mad respect for the work that obviously went into this. Great job!

u/no_witty_username avatar

These types of posts are what I really look forward to, so props for the write-up. I make a lot of models myself (thousands at this point) and am always reading anything interesting about people's training processes. Currently I am baking a 50k-image dataset LoRA and the preliminary results are looking quite good. I am slowly pushing the dataset counts and learning from mistakes along the way in hopes of one day baking a 1mil LoRA myself. Anyways, enough of my jabbering, I have some questions if you don't mind answering.

  1. How much did it cost you for that one training run? 8xH100 for 5 days. What platform did you train on? Were these serverless GPUs, or did you have access to a virtual machine and were able to choose whatever operating system you were training on, and so on? Did you have any GPUs go tits up on you during that 5-day training run?

  2. What settings did you use? LR, dropout, batch size, normalization images, etc.? If you can share your JSON settings that'd be awesome, or share what you can. I used to train on Kohya_ss quite a lot, and finding the proper hyperparameters always took a few days at least, and even then I couldn't be 100% sure if the model would blow up 100k steps in or so... So I started using OneTrainer and Prodigy to alleviate the issue. I was just wondering if the hyperparameter search is a me-only problem or if others also have to struggle with it for a few days before a training run.

  3. You said you had your own VLLM do the tagging for you. Did you train the VLLM on your own dataset and captions? If so, can you go a bit into that? I am also considering training my own VLLM, but I find information lacking online on how to go about doing that successfully.

  4. Did you train both the text encoders and Unet? Did you train the model on top of SDXL base or a custom finetuned base model? For captions, did you store them all in one json file or did each image have its own accompanying text file?

I am sure I'll come up with more questions later on, but I'd love to know at least these for now. I'm downloading your model now to play around with it. Cheers.

u/fpgaminer avatar

Just to be clear, I did a finetune, not a lora, so the hyperparameters and such are gonna be quite a bit different. Anyway:

The final run uses an effective batch size of 2048, no EMA, no offset noise, PyTorch's AMP with just float16 (not bfloat16), 1e-4 learning rate, AdamW, min-snr loss, 0.1 weight decay, cosine annealing with linear warmup for 100,000 training samples, 10% UCG rate, text encoder 1 training is enabled, text encoder 2 is kept frozen, min_snr_gamma=5, PyTorch GradScaler with an initial scaling of 65k, 0.9 beta1, 0.999 beta2, 1e-8 eps. Everything is initialized from SDXL 1.0.

How much did it cost you for that one training run? 8xH100 for 5 days. What platform did you train on? Were these serverless GPUs, or did you have access to a virtual machine and were able to choose whatever operating system you were training on, and so on? Did you have any GPUs go tits up on you during that 5-day training run?

The 8xH100 system was like $27/hr to rent. I use Lambda. You just get a machine to SSH into. Over all the time that I've used them I think I've had maybe one GPU die during training, which of course they refunded and credited me for. Quite happy with them ... when they have GPUs available :P

You said you had your own VLLM do the tagging for you. Did you train the VLLM on your own dataset and captions? If so, can you go a bit into that? I am also considering training my own VLLM, but I find information lacking online on how to go about doing that successfully.

For this run I just used JoyTag, which is strictly a tagging model. Finetunes like Pony use a combination of a tagging model (or existing tags on the images), and feeding those tags to a VLLM. So you basically prompt the VLLM to incorporate the tags into its descriptions. That helps ground the VLLM to reduce hallucinations, increase accuracy, and align it with how users typically prompt (a mixture of descriptions and tags). No clue what VLLM Pony uses, but from what I hear CogVLM gets used a lot elsewhere. There's also LLaVa, Intern, ShareCaption, etc.

I definitely want to do a VLLM for future versions. For the first version, proving JoyTag as a useful tagger and building a photorealistic model that can be prompted with tags was my primary goal.

Did you train both the text encoders and Unet? Did you train the model on top of SDXL base or a custom finetuned base model? For captions, did you store them all in one json file or did each image have its own accompanying text file?

Just the Unet and TE1, I kept TE2 frozen.

Because I use a custom training script I can store the training data in more creative ways. All of my data churning to tag the images, generate the latents, detect watermarks, etc. is handled by separate Python scripts. And they all operate off a single sqlite3 database to keep track of everything, filling in their respective columns. So to keep things simple I just set up my training script to read what it needs out of that same sqlite database. That way spinning up the training machine just involves transferring the latents, the sqlite3 database, and the training script.
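
As a sketch, the training script's side of that is basically just a query (the table and column names here are illustrative, not the actual schema):

```python
import sqlite3

conn = sqlite3.connect("dataset.sqlite3")
rows = conn.execute(
    """
    SELECT latent_path, tags, score, watermark
    FROM images
    WHERE score > 0   -- score_0 images are dropped entirely
    """
).fetchall()
```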

u/no_witty_username avatar

Thanks for the reply. I wish you luck on finetuning the VLLM. I've been using CogVLM to caption my dataset and indeed it is better than the competition. But no VLLM is good at NSFW captioning. None of them have been trained extensively on NSFW content, so they don't know the nomenclature nor understand even the most basic positions. Often when a subject was lying on her back it captioned that she was lying on her stomach, and so on. That's why I wanted to finetune my own model... But alas, we take what we get.

u/Electrical_Analyst_7 avatar

Without going NSFW, have you tried using gpt4o for labeling? I'd consider it because Juggernaut X xl uses them.

u/shawnington avatar

Normalization is misunderstood. People do poor training, and they cook what they are trying to train into a concept, and try and uncook it with normalization. You are just re-cooking it. Normalization is universally bad. If you need normalization, it's because there is a lack of specificity in the captioning of your data.


Real slam dunk of a post. Very much appreciate the level of detail as someone about to step into SDXL training.

Would you be willing to go into a bit of detail as to how you go about leveraging JoyTag to automate captioning on such a large dataset? Did this require you to custom-code a script for the task, or is this something that can be performed through existing tools? One of the big missing pieces of the puzzle for me is the nuts and bolts of how you guys are captioning these big datasets with cutting-edge VLLMs.

u/eraser851 avatar

TagGUI is a great tool for captioning.

I have and use it. Great tool. Would you know where to find information on how to use TagGUI with VLLMs that it isn't already configured to use out of the box, such as JoyTag?

u/eraser851 avatar

Unfortunately, I haven't been able to get custom models to run. You can experiment with specifying a directory for local models in the settings. They'll show up at the top of the list of models after a restart.

Thanks, I didn't know that. I'll poke at it. The installation directory contains auto_gptq, so I suspect you have to use models that have been quantized in that manner for them to run.

Edited

Most of this aligns with all my experience. However, I use natural language captions and don't put in score_x tags.

I prefer fp16 mixed precision because it better aligns with the actual magnitude of the weights and activations than bf16.

Edited

This is a really impressive write up. So many good details.

Thank you for writing and sharing it 👍🙏.

You said that it cost $27/hour to rent the GPU, so that's 27 * 24 * 5 = $3240 to train the model?!

u/SevereSituationAL avatar

Yes, a lot of model makers have Patreons because making models is time-consuming and costs a lot of money.

Indeed 🙏

u/fpgaminer avatar

Yeah... cheaper and more effective learning than college?

Yes, nothing beats learning by doing when it comes to the practical side of A.I. 😅

Absolutely.


Finally, some great content worthy of this subreddit.

u/AstraliteHeart avatar

10% of the time the prompt is dropped completely, being set to an empty string

So a big question - did that actually work? I guess no other option than to train model once again with no dropout. I stopped doing it in V6 as one of the attempts to simplify training so at bare minimum it's not required but measuring impact on quality would've been great.

Pure fp16 training did not work for me

Can you please expand on what didn't work?

For every training sample, it picks a random integer in this distribution to determine how many tags it should use for this training sample. It shuffles the tags on the image and then truncates them to that number.

That is very cool and I never thought about it, Pony does shuffle tags but always dumps anything I have.

cosine annealing

it's interesting you went with cosine instead of const; it made sense for Pony because of the really high initial LR, but I assumed that for a base-SDXL-level rate const would make more sense. What was your reasoning?

 That said, 30M training samples, as large as it is to me, pales in comparison to Pony XL which, as far as I understand, did roughly the same number of epochs just with 6M! 

When you say 30M samples, do you mean in 20e the model saw 30M images, i.e. your training dataset is 1.5M unique images?

A stable training loss is also logged

To my shame I am unaware of this, any code references?

From that I could run ELO on a larger dataset to build a ranked dataset,

This is new to me. I use ELO now for human feedback, but the actual scoring model is a dumb "clip => score" one (although I am experimenting with ViT features now). My naïve assumption is that self-play like this may not be efficient, but I'm very interested in your reasoning.

u/fpgaminer avatar

bows with respect

So a big question - did that actually work? I guess no other option than to train model once again with no dropout. I stopped doing it in V6 as one of the attempts to simplify training so at bare minimum it's not required but measuring impact on quality would've been great.

It's a good question. I dunno if the dropping does something ... special? to the model to enable CFG, or if it was historically only done because we didn't have this idea of "negative prompts" yet. The exact "why" of why CFG works is a mystery to me, so I've always internalized it as "moving" the model away from "something". With an empty prompt, that's likely to be intended as the "average image" and thus drive the gen towards "better than average". But who knows.

It can definitely be tried ... just expensive as hell.

Pure fp16 training did not work for me Can you please expand on what didn't work?

The validation and stable losses drop for the first 100k samples, but then flatline. I ran it out to ~10M training samples, but it just remained flat. Mixed training worked fine and validation loss continued to decrease. I played a little with the gradscaler, just plain loss scaling, adam eps, adam beta2, but no joy.

I'd love to get pure fp16 working, since it's a nice memory saving and a good clip faster iteration speed. And it's probably possible; IIUC LLaMA is trained with pure fp16. But for me it wasn't worth the compute I was wasting trying to get it working.

it's interesting you went with cosine instead of const, it made sense for Pony because of really high initial LR but I assumed for base SDXL rate const would make more sense. What was your reasoning?

Nerves. Animagine used cosine annealing and they used a dataset with the same size as mine, so I decided to use their hyperparameters as the best starting point. Only big tweak I made was increasing the batch size and LR to match.

I'm debating doing another run with constant LR to compare. Right now I have an ablation with a lower WD running.

When you say 30M samples, do you mean in 20e the model saw 30M images, i.e. your training dataset is 1.5M unique images?

Yes, roughly. Something like 20.8 epochs on a 1,440,000 dataset.

To my shame I am unaware of this, any code references?

https://github.com/fpgaminer/bigasp-training/blob/53f31802531774a392eb9b54def8f74ca6cd679e/train.py#L209-L214

https://github.com/fpgaminer/bigasp-training/blob/53f31802531774a392eb9b54def8f74ca6cd679e/train.py#L892

Those are the important bits. It's essentially just a second validation pass, but with the data being a fixed slice from the training data. And of course the seed is fixed, just like validation, so all the random factors (like timesteps) are fixed. That way it's nice and stable; just as smooth as the validation loss curve for me. And should be very reflective of how well the model is learning the training data, versus how well it's generalizing using the validation loss. Sample images are always king, of course.

This is a new to me. I use ELO now for human feedback but the actual scoring model is a dumb "clip => score" one (although I am experimenting with ViT features now). My naïve assumption is that self play like this may not be efficient but very interested in your reasoning.

I dunno, I probably had some good reasoning. But it actually just comes down to: I'm one human and I'm trying to get a cool image gen model trained not a quality model. For the first iteration, good enough will be good enough.

I figured a human decided ELO would require a lot of manual ratings. I wanted a diverse set of images in the arena, like 64k. So, I'll do a thousand manual head-to-head ratings and throw thoughts and prayers at CLIP.

I don't think my quality model is great, but it's better than the old aesthetic model, it got all the junk images into score_0 for pruning, and it has the general trend I was looking for across the buckets. So, small win for self play? shrug

u/AstraliteHeart avatar

Thank you for answering!

I dunno if the dropping does something ... special?

My understanding is that dropping captions generally aligns the model with the "average" dataset image (which may improve quality), but with the presence of aesthetic tags it doesn't help much.

Animagine used cosine annealing and they used a dataset with the same size as mine, so I decided to use their hyperparameters as the best starting point. 

I see, thank you. It is really annoying we don't have a good answer. My plan for v7 is to try constant and see how it works.

And should be very reflective of how well the model is learning the training data, versus how well it's generalizing using the validation loss. Sample images are always king, of course.

You said you were generating samples for validation on another machine, but wouldn't calculating the stable loss pretty much do the same (and eat memory)?

u/fpgaminer avatar

You said you were generating samples for validation on another machine, but wouldn't calculating the stable loss pretty much do the same (and eat memory)?

It's just a loss, no VAE needed. So no extra memory.

But yeah, does consume compute. 2048 forward passes vs 80 per sample image. I dunno, having those metrics is useful to me. But maybe it's just because I come from training more standard ML models and so I'm used to watching training and validation loss. Seems difficult to measure overtraining without either doing a lot of samples to check diversity of gens, or I guess maybe FID score.

u/AstraliteHeart avatar

I agree on lack of good metrics, loss was pretty useless in my training runs so I just eyeballed results based on about 100 prompts for every epoch.

u/Aware_Photograph_585 avatar

I'm running tests on FID/KID/IS/LPIPS/HPSv2 to score the model as it's training, using 1% of a 25k dataset for these tests and for validation loss. Halfway through training, so nothing confirmed yet. HPSv2 trended down then flatlined for the 1st half of training, and looks about to come back up. If so, it could be a good metric. Nothing of note from the others yet.

u/AnOnlineHandle avatar

I've tried using two different finetuned checkpoints to do CFG (juggernaut which is good with empty prompts for the blank prediction, and a 1.5 finetune trained on my dataset for the captioned prediction). It didn't work at all, despite both models being nearly identical at the higher level (and in a weights comparison).

It seems that you actually want the unconditional prediction to do most of the heavy lifting, and for the conditioned prediction to be quite similar due to coming from the same model, just giving a hint about what part of the prediction to amplify to make it better fit the prompt. A strong unconditional prediction implies a model which is generally good at denoising, it seems.

That being said, it also seems that the text encoder should be frozen whenever the prompt is dropped, as you don't want the whole text encoder to fit to your uncaptioned images (where a blank prompt is still filled with padding tokens which are just as trainable as any other tokens).

u/fpgaminer avatar

Follow up on the cosine vs const schedule. I lit some money on fire:

https://api.wandb.ai/links/hungerstrike/9kr1o2jz

Const started to diverge between 7 and 10 epochs (10M samples to 15M samples), getting worse than cosine on validation loss. I stopped the run shortly after.

That follows my experience with other models like ViT, where cosine performs better when the epoch count is high. Constant really only seems to work well if the epoch is 1 or close to it, like the foundational training of SDXL itself, or LLMs.

Kinda makes me wonder if a 1Cycle would perform even better, or at least more efficiently. But I've burned enough money for now. Maybe next version.

u/shawnington avatar
Edited

I'm just baffled every time someone tries to throw tons of autocaptioned data and compute at a problem. Literally anyone with any experience knows that 100k very well captioned images will outperform 1.4M auto-captioned images that don't reflect the data accurately.

1 miscaptioned image does more damage than 10 correctly captioned images do good.

The state of autocaptioners is really bad: they are not advanced, they are not accurate, and they destroy datasets.

I've had the one you used caption a photo at a beach as someone in a bathroom with a clearly visible sink, when the "sink" was actually a sailboat.

That kind of inaccurate captioning, on even one image, does a lot of damage. I think a limit that you are rightfully exploring is language precision, and there are probably real limits in English or other languages when it comes to describing scenes non-ambiguously. I'm not sure whether increasing the ambiguity with synonym swapping will increase understanding or act like token merging during training, and there is a token merging limit where quality starts to decay. Not sure what your experience was with that vs runs without it. It's not bad in theory, but if it's also swapping out on miscaptioned data, it's going to bake in error as well.

But also, props for sharing your process, this is how people learn, and definitely there is not very much information out there.

Also, a lot more needs to be understood about learning rate schedules. On completely unrelated (domain-specific) models, different learning rate schedules have been the difference between taking 3 days to train and 15 minutes to train, in my own experience. There is much to gain in how learning rates are scheduled.

Everyone trains with really basic schedules that are very very inefficient.

How is this baffling? Not everyone has the resources to pay 10 people to produce high-quality manual captions for 100k images, let alone 1.5M. Hobbyists don't have commercial or researcher-level resources. We do the best we can with the resources we can muster.

I wholesale disagree with your assessment. A few miscaptioned images in a fine-tune have very little impact on the quality of the end product. I have fine-tuned the same dataset of 20k images at least 10 times. I have made incremental changes to the captions on many of those runs. I can see the changes that are wrought. I can see the effect of errors. It is certainly there, but it doesn't produce "a lot of damage".

I really get the impression that you've never actually performed a large-scale fine-tune and are most experienced with LoRAs. Smaller-scale training is more greatly affected by caption issues, and nothing over 5k images trains in 15 minutes unless you're using a bank of commercial GPUs; literally anyone with any experience knows that.

u/shawnington avatar

It depends on the magnitude of the error. Captioning a wolf as a dog a few times is clearly not as detrimental as captioning it as a motorcycle or an airplane.

Also, I referenced the training of an unrelated domain-specific model, meaning one having nothing to do with generative imaging at all, but still a U-Net-based architecture, trained for document recognition and segmentation. I was referencing the effects different learning rate schedules had on training that network, which was trained on a dataset of over 1 billion synthetic documents.

The point being, experimenting with training schedules can have very significant impacts on how quickly and how much compute is required to achieve the results you are after.

I brought that up because especially for hobbyists with limited resources, most people are sticking to your basic constant lr, or something with warmup, and have never tried something like cosine annealing, or other lr schedules which can drastically speed up training, and reduce the number of epochs required.

u/Xamanthas avatar
Edited

How do you propose to humanly tag ~100k images with no inaccuracies? Happy to hear more ideas, as it's the same problem facing us after we finish training the base model with 5.1M images.

You can't trust agencies because they won't know the ins and outs. (Also, if you do have a source I would love to read it, not that I disagree with you.)

u/shawnington avatar
Edited

There is no easy answer, and it is why datasets are the most valuable commodity in AI. Just look at the difference between Playground 2.5 and SDXL: Playground consistently outperforms it in every preference and prompt adherence test. It is the SDXL architecture, retrained on higher quality data.

I think a compromise can be made with auto-captioning by using a cohort-of-agents approach: have several auto-captioners caption an image, and establish a threshold of difference in captioning that requires human review or excludes the image from the dataset, or use another model to write a consensus opinion.

The conversation should go deeper than auto-caption vs human captioned though, because captioning style, and what needs to be captioned how, is an ongoing discussion.

I think there is some merit to the idea of varying the caption for an image and having it repeated in the dataset, similar to how including different crops is beneficial.

I'm not sure if using synonyms for the same image is the way to do it. Like I said, that sounds like it might be baking in token merging, and token merging is known to reduce quality slightly. However, I'm not sure if this would be the case, or if it would reduce the ambiguity that is naturally present in captioning in an ambiguous language such as English, and actually make it better at prompt adherence. That is something I'd be very interested in knowing.

u/Xamanthas avatar
Edited

Cheers for taking the time to respond; we were mucking around with that idea too (multiple agents). I'll try to keep you in mind if we manage to develop anything worthwhile, though if you are interested in helping with an NSFW artwork and anime (but non-suspect) focused community finetune, do let me know. Currently we are focused on refining the dataset and trying to build solutions.

As for varying the caption, the approach we want to experiment with is tagging and then captioning the same image with natural language. If that fails, then only a subset gets natural-language captions.

u/shawnington avatar
Edited

I'd be willing to offer any help or ideas.

Something I'd be very curious to try, considering text encoder training was done and you are talking about splitting captions into two styles, would be split captions: one that is strictly a list of everything in the image, e.g. "sailboat, wooden dock, rocks, beach, waves", for a CLIP-L encoder training pass, and a natural language caption like "A sailboat moored at a wooden dock by a rocky beach with waves breaking on the shore" going to a CLIP-G training pass.

The idea behind this being that since SDXL has CLIP-L and CLIP-G, and CLIP-L is better with word salad while CLIP-G is better with natural language, you feed CLIP-L the word-salad description (basically semantic segmentation tags of the image) and the natural language description to CLIP-G.

I've also been curious, since Segment Anything was released, whether there is any value in trying to create a list of captions from the segmented items. So basically masking the image by segment, having the auto-captioner caption each segment individually, and then either condensing that down or training the segments as crops of the image. That might be a way to effectively increase the token length of the context it's learning about a particular image.

I've always suspected that the limits of token lengths can be kind of gotten around by splitting descriptions up between the two text encoders like this for training.

I do disagree with the prompt length variation, though, unless there is a lower bound at which everything in the image is captioned and the length variation just covers variations in verbosity. Intentionally omitting things or concepts from being captioned is how biases get baked into concepts.

So unless things are being omitted specifically to bake them into the concept, in general it's been established that using all available tokens in training is better (see PixArt-Sigma; they describe fairly in-depth the difference that moving from 75-token to 300-token captions makes in prompt adherence).

u/dvztimes avatar

If I wanted to finetune a test model, is there a good guide or place to start? How many images would I need? I've trained more than 100 LoRAs on my 4090 but would like to try my hand at a model.

Thank you.

u/shawnington avatar

This is an "it depends" answer. It depends on what you want the model to do, how specific you want it to be, and what kinds of / how many styles and concepts you want to add or augment.


ChatGPT-4o is pretty darn amazing at captioning accurately. I keep thinking that Civitai.com has a gold mine of millions of human-generated and preference-ranked prompts/captioned images.

Do they really, though? They have lots of images with quality scores from humans. But those humans aren't looking at the prompt and considering how close the image is to what the prompt asked for. They are just scrolling through a bunch of images and pressing "like" occasionally.

And so, so often on an image you will look at the prompt and realise the person who generated the image was just throwing shit at the wall and happened to get lucky with one image. So the prompt will be full of stuff not represented in the image (they’ll have four different concepts in the prompt, and sometimes what is in the image won’t even be what is in the prompt).

So I think they have a fair bit of data on what images humans like. But I don’t think they have much useful data to tie that back to what the model was asked to generate.

u/Aware_Photograph_585 avatar

Nice work. Your hyperparameters align well with my own test results. Also, thanks for including the validation loss curve; it confirms that my curve looks correct.

Original/Target size are useful if you're upscaling images before training. Mostly useful if you've got a bunch of 256/512 images that you want to train with. Not something you'd need with your high-resolution dataset. I'm running some tests on it now.

Crop_top_left helps prevent bad crops from creating out of frame generations. Also, probably not a problem you'd have with your dataset, assuming your crop script didn't do crazy croppings.

Stable training loss is a new one for me, I'll need to add that to my script.

Regarding the vae, I think it's a major weak point. I think a better vae decoder would help a lot. I'd like to fine-tune the vae decoder, but I can't find information on how to do it.

u/fpgaminer avatar

Glad the loss curve was helpful. No one ever seems to publish those, leaving everyone else looking at theirs "Is ... is it supposed to do that!?"

Original/Target size are useful if you're upscaling images before training. Mostly useful if you've got a bunch of 256/512 images that you want to train with. Not something you'd need with your high-resolution dataset. I'm running some tests on it now.

Yeah, it's a mysterious one. I've tried tweaking those conditionings during inference on other models and it kinda sorta adds more details? So I figured it was good to include in my training. But I'm a little worried that stock SDXL has never seen resolutions this large... (a lot of images in my dataset throw PIL warnings about zip bombs because they're so large).

Regarding the vae, I think it's a major weak point. I think a better vae decoder would help a lot. I'd like to fine-tune the vae decoder, but I can't find information on how to do it.

I think others have tried finetuning the VAE, and there's a few variations out there. Honestly I think the arch itself just won't be able to do it. OpenAI is probably on the right track with their diffusion based decoder, but from what I've heard it's a memory hog so isn't useful for local gens.

u/eraser851 avatar
Edited

I've been diving down into the descriptive captioning rabbit hole.

Check out TagGUI.

The Xcomposer2 model does a decent job with NSFW images.

I've been doing it in a two step process: tag the images first with something like wd-swinv2-tagger (or JoyTag).

Then change the model to Xcomposer2 and include in the prompt something like "Use the following tags for context: {tags}" (You'll likely spend lots of time tweaking your prompt for the VLM).

CogAgent & (the new) CogVLM2 does fantastic with SFW content.

I'm still experimenting on having a dataset with captions AND tags, or just one or the other. I believe SD3 has something like half captioned images, half tagged images. I'm also toying with caption enhancing, like how Dalle3 was trained. Still trying to find the right LLM to do that.

Woah, I probably only got half of this. Thankfully smart people are horny too and create sexy models for us folks.

u/Aware_Photograph_585 avatar

What framework & libraries did you use for multi-gpu? Accelerate, deepspeed, FSDP, DDP, etc?

u/fpgaminer avatar

Just good ole PyTorch DDP

u/Aware_Photograph_585 avatar

Wish I could use PyTorch DDP, but SDXL won't train with DDP & mixed precision fp16 on 24GB vram.

I wrote my own multi-GPU SDXL trainer, currently using Accelerate DeepSpeed ZeRO stage 1 w/ CPU offload to train on 2 RTX 4090s, getting ~7 imgs/sec at mixed resolution 768-1024.

Accelerate & DeepSpeed have some minor issues, but it works well enough. I'd like to re-write it in pure PyTorch FSDP, but I'm still a novice programmer. I'm going to read through your script in hopes of learning some new stuff. Thanks for sharing it.

u/nupsss avatar

Sounds like fun! Any specific reason that moved you into doing this?

I am curious: how do you handle such gigantic image datasets? Can't just host a cloud and spend the first few hours uploading images, right? 😅

Thanks for this.

Now can you give us the link to your Model so that we can see the results?

u/fpgaminer avatar

It's a nsfw model so I didn't link it, but just search civit for bigasp.


Interesting reading.
A question and a comment:
You said you wanted to avoid your model being "only one style".

Why not explicitly identify styles in your input images, then split the highly styled images off into storage for later LoRA training?
Use general-case, non-"styled" images for the larger training, then provide LoRAs for any particular desired style. I would think that should result in the most comprehensive, flexible base model, while at the same time being able to provide a wide variety of styles to the user.

You seem to be impressed by PonyXL, so... it seems worth mentioning that's basically how Pony gets used.

u/terrariyum avatar

This may well overturn the entire industry of "realistic pony" merges

u/HardenMuhPants avatar

Even using commercial hardware, I shudder at the thought of caching a million latents.

TBH the VAE isn't very big in terms of memory, especially in the context of 40GB+ cards. It's only 335MB, or half that in FP16, and it's safe to use it in full FP16 during training.

Also it only pays off so much on compute time if you run enough epochs. If you run 20 epochs, you run the VAE on every image 20 times. That may seem like a lot but the VAE is very fast compared to the Unet and the backwards pass. The VRAM used during actual forward for activations, etc. is moot as it is all ephemeral, and it is run out of sync with unet forward/backward.

I guess in the context of renting GPUs, you can precompute the latents on a much cheaper PC nearly for free ahead of time, and you save a bit of time transferring your data. Depending on the GPU service you might be able to locate your data in their DC ahead of actually renting the GPUs making that sort of moot.

It'd be low on my list to optimize, but I suppose it does save some dollars.

u/Venthorn avatar

It's only 335MB, or half that in FP16, and it's safe to use it in full FP16 during training.

You have to include the cost of storing the full image in memory (x batch size) as well, instead of just latent image x batch size. And associated time cost for the extra sized reads from disk.

Edited

Torch dataloader reads ahead of training steps on separate threads from the main training thread, so this adds zero latency to the training process. It truly does not matter one bit.

VAE encoding does add latency since it's more GPU compute and you have to wait for the VAE to execute on the batch.

The RGB images are pretty much trivial amounts of VRAM, even with high target resolution and large batch size.

u/fpgaminer avatar

The real kicker is the 1TB of images vs 100GB of latents :P. And VAE is more of a pain for SDXL than it was for SD1.5. I definitely don't remember it being a consideration for 1.5, but it had a big hit on training speed for SDXL. For me, at least. Besides, A100s and H100s are starved for batch size, so the more memory I can use for training batches the better.

u/SevereSituationAL avatar

Great job. Though the timing is so close to SD3's release (soon hopefully). A lot of people are still waiting for SD3 and are holding off on training SDXL. Still, thanks for sharing your knowledge.

u/codyp avatar

Thank you for this.

u/aerilyn235 avatar

Where did you get your data from?

Amazing information, thanks!

u/ImplementComplex8762 avatar

so is it better than pony?

u/One-Culture4035 avatar

If you are fine-tuning instead of training a LoRA, ControlNet and IP-Adapter plugins may not be fully compatible. Did you give them a try?

u/fpgaminer avatar
Edited

I haven't tried those yet, no. I've been down in the training trenches, haven't had much time to exercise the model yet.

u/Saltfireflame avatar

Absolutely great, thank you.

u/AmazinglyObliviouse avatar

Nice post, and I'm enjoying the model too. Do you have a discord server to follow, in case you decide to train SD3 too? I'd like to follow that :)

u/fpgaminer avatar

Nah, I don't want the PITA of a discord server. I post my models on civit, so probably easiest to just follow that for updates down the road.


OK, so where's the model? I want to try it.

u/fpgaminer avatar

It's on civit under the name bigasp. I won't link here, since it's a nsfw model.

Will you train again on SD3 with these datasets?

u/Devajyoti1231 avatar

Amazing post.

I don't understand the image rating thing, like do you rate it yourself or can a tagger like JoyTag do it automatically (can't imagine rating 1M images manually).

Also by training prompt do you mean image captions/tags?

u/fpgaminer avatar

I don't understand the image rating thing, like do you rate it yourself

I only had to rate 2,000 or so images manually in "head to head" matches. I have a UI that shows me two random images and I pick the "better" one ... two thousand times. From that and with a bit of magic I can train a model to automatically rank images from 0 to 9. So that automated model is the one that ranked all of the images used to train bigASP.

Also by training prompt do you mean image captions/tags?

If I were to take an image of a cat and run it through a tagging model, or manually tag it, I'd get tons of tags for that image. Cat, outside, day, 4k, grass, tree, sun, sky, etc etc.

But when a user generates an image of a cat, they usually aren't that specific. A handful of tags, maybe. They prompt for a few things they care about, and leave it to the model to fill in the rest. If they don't care whether it's day or night, they probably won't include either tag and will expect the model to randomly choose. Right?

So all the images in my dataset have all these tags, but I believe it would be a bad idea to include them all, as-is, as the prompt for the image during training. So I have this concept of a "training prompt", which is generated randomly on the fly by the dataloading code. It shuffles the tags, picks a random length, drops tags to fit that length, incorporates score tags, applies synonyms, etc. Its job is to generate prompts like a user would.

That way, hopefully, the model doesn't become hard to control and hyperspecific. Instead, it's used to the way users prompt, and generates a good diversity of images while still following what the user asked for.

And because the training prompt is made on the fly, and is random, that means that between epochs the model sees totally different prompts for the same image, helping to prevent overfitting.

I hope that makes a little more sense. I actually shared my training scripts in another comment, so the code is available for this in the data.py file.

u/Devajyoti1231 avatar

Thank you so much.


Don't use the original VAE. There's a reason they made a fixed VAE.

Are you going to release the weights for this model? Do you find that it is better than the popular SDXL finetunes like JuggernautXL?

u/Aware_Photograph_585 avatar

Could you provide some more information about your scoring model? I've been trying to figure out how to create a scoring model, but don't know where to find ranked images for training the model.

Mainly, I'm just looking for a way to automatically remove low quality images from my dataset. I may just start a folder to collect all the images I remove and train a model against that. Not sure what else I can try.

u/fpgaminer avatar

For the ranking itself, I just ranked 1k-2k images myself (in head-to-head matches). The scoring models use CLIP embeddings, so it doesn't take too much data to train them to "good enough". Though you can certainly automate the process to build a seed dataset that you can build off of. Many papers have shown that GPT4V matches human performance for ranking the quality of images. And there's the older aesthetic model (https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main) that you can seed from.

I actually end up with a few different databases to deal with all of this. Head-to-head image rankings to directly train the scoring models. "Junk" and "Not junk" for training a "junk detector" model. "Watermark" and "Not watermark" which originally I used for training a watermark detector but stock OWLv2 was good enough so this just became the validation dataset for comparing watermark model performance.

All of that data was done by hand. Took maybe a couple days of work. Not full days.

CLIP can help a lot here. I pre-encode the CLIP-B embeddings for all images in my dataset. So I set up a Python notebook, and can do a quick text similarity search through the dataset for searches like "thumbnail", "advertisement", "grid", "video preview", etc. That lets me quickly find and manually tag both positives and negatives for the junk dataset. I also make sure to manually comb through and tag without the assistance of that search, in case there are any biases in CLIP's embeddings.

For junk images, I found it fairly easy to spot them even if the images were small (256x256). So I have the Python notebook grab 100 random images from my dataset, use the CLIP similarity scores to sort them from "most junk like" to "least junk like", and then arrange them into a big grid of images with "junk" and "not junk" buttons underneath each. Takes 60 seconds to go through; hit 10 junk buttons, 10 not junk buttons, rinse and repeat.

I built up maybe 200 or so junk/not junk ratings that way. Trained a model. Then I go back and do the same thing, but I use the trained model to rank images. This time I focus heavily on images where the model is making mistakes or isn't confident enough. Since I have this sorted grid of images it's easy to do.

Tag another 50, retrain, tag again, retrain, etc until the model isn't making any more mistakes. The cycles are fast, so it's tedious but doesn't take long.

Now, you can either use the junk detector directly to weed images out of your dataset. Or, in my case, I just used it to "validate" the scoring model and assist in finding junk images to throw into my head-to-head scoring UI. The scoring model is a very similar workflow. Score 100 images, train the model, analyze how it's performing. Use the junk model to find stuff it isn't ranking down at the bottom, etc. Score 100 more, retrain, repeat.

The "models" used here are all just models that take CLIP embeddings as input, do a linear projection or a two layer MLP if I'm feeling fancy, to a softmax. Again, all the CLIP embeddings for my dataset are precomputed, so these models take no time to train. I use a bit of dropout, 0.1 weight decay, adamw, and the LR finder + 1Cycle strategy to train them along with a validation split to measure performance. Quick and dirty.

EDIT: One quick addendum. Almost all of this work, like I mentioned, is done in Python Notebooks, which make it easy to quickly stand up UIs for tagging/ranking/etc images and for visualising images->predictions.

u/Aware_Photograph_585 avatar

Wow. Thanks for the detailed responses. Looks like I need to learn how to use CLIP and work more with Python notebooks, instead of .py files like I've been doing.

Again, appreciate all the info and help you've given the community!

u/alexds9 avatar

What do you think about chronic issues that SDXL models have with positioning bones at the wrong places, missing/extra belly buttons, deformed skin textures, and hands looking like origami?
It seems to me that SD 1.5 understands the body better than XL on a fundamental level if we ignore occasional limbs salad for a moment.

Awesome documentation, and thanks for the same.

u/LD2WDavid avatar

Super post. Enjoyed a lot reading it.

I have had the chance to see a bit of this process from some partners, and it's insane. How much time did you put into the scoring model? Probably a key part of this.

Really, really cool. Saved for further re-readings. Will take a look at JoyTag. I liked the mutation part giving people prompt "freedom".

u/fpgaminer avatar

I didn't spend much time on the scoring model, which is why I think it could be improved to reduce some of its bias towards "photoshoot" style photos. I spent maybe two days manually scoring images?

Vision models and some augmentations helped to bulk up that dataset. For example, I randomly blur and JPEG compress the images in the dataset, where the "corrupted" version of the image always loses to the original.
