# LaWa: Using Latent Space for In-Generation Image Watermarking
This notebook reproduces the LaWa results. The paper has been accepted for publication at ECCV 2024.
<p align="center">
<center>
<img src="https://vbdai-notebooks.obs.cn-north-4.myhuaweicloud.com/lawa/framework.png" alt="alt text" width="1000">
</center>
</p>
## Download and extract the code
```python
!wget https://vbdai-notebooks.obs.cn-north-4.myhuaweicloud.com/lawa/code.zip
!unzip -qo code.zip
```
## Install Required Packages
We have tested our code with Python 3.8.17, PyTorch 2.0.1, torchvision 0.15.2, and CUDA 11.3. You can reproduce the environment with conda by running:
```python
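!conda env create -f environment.yml
# Note: each `!` command runs in its own subshell, so this activation does not
# persist to later cells; from a terminal, run these commands without the `!`.
!conda activate LaWa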
```
## Inference
Run the following script to download our pretrained modified decoder as well as the original decoder. These weights correspond to the KL-f8 auto-encoder model and 48-bit watermarks.
```python
!bash download.sh
```
Model weights will be saved to `weights/LaWa/last.ckpt` and `weights/first_stage_models/first_stage_KL-f8.ckpt`.
Additionally, download the Stable Diffusion v1.4 weights from [here](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt) and save them to `weights/stable-diffusion-v1/model.ckpt`.
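Before running inference, it can help to confirm that all three checkpoints are where the scripts expect them. The following is a minimal sketch that uses only the paths listed above; the helper name `find_missing_weights` is hypothetical and not part of the LaWa code.

```python
from pathlib import Path

# Checkpoint locations named in the instructions above.
EXPECTED_WEIGHTS = [
    "weights/LaWa/last.ckpt",
    "weights/first_stage_models/first_stage_KL-f8.ckpt",
    "weights/stable-diffusion-v1/model.ckpt",
]


def find_missing_weights(root="."):
    """Return the expected checkpoint paths that are not files under `root`.

    Hypothetical helper for illustration; it is not part of the LaWa repository.
    """
    root = Path(root)
    return [p for p in EXPECTED_WEIGHTS if not (root / p).is_file()]


missing = find_missing_weights()
if missing:
    print("Missing checkpoints:")
    for path in missing:
        print("  " + path)
else:
    print("All checkpoints found.")
```

Run it from the repository root so that the relative paths resolve correctly.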
To generate watermarked images using Stable Diffusion and LaWa, run:
```python
!python inference_AIGC.py --config configs/SD14_LaWa_inference.yaml --prompt "A white plate of food on a dining table" --message_len 48 --message '110111001110110001000000011101000110011100110101' --outdir results/SD14_LaWa/txt2img-samples
```
This saves the original and watermarked images, along with their difference image, to `results/SD14_LaWa/txt2img-samples`. It also generates `results/SD14_LaWa/test_results_quality.csv` and `results/SD14_LaWa/test_results_attacks.csv`, which summarize the visual quality of the watermarked image and its robustness to attacks.
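The `--message` argument in the command above is a string of `0` and `1` characters whose length matches `--message_len`. To try other messages, a small helper can generate and check such strings. This is a sketch only: `random_message` and `validate_message` are hypothetical names, and it assumes the script accepts any bit string of the stated length.

```python
import random


def random_message(message_len=48, seed=None):
    """Return a random string of '0'/'1' characters of length `message_len`."""
    rng = random.Random(seed)
    return "".join(rng.choice("01") for _ in range(message_len))


def validate_message(message, message_len=48):
    """Return `message` unchanged, or raise ValueError if it is malformed."""
    if len(message) != message_len:
        raise ValueError(f"expected {message_len} bits, got {len(message)}")
    if set(message) - {"0", "1"}:
        raise ValueError("message may contain only '0' and '1'")
    return message


# Generate a reproducible 48-bit message and check it before passing it on.
message = validate_message(random_message(48, seed=0))
print(message)
```

The printed string can be passed to `--message` in place of the example value.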
## Train your own model
### Data Preparation
Download the MIRFlickR dataset from the official website. `data/train_100k.csv` lists the images we used for training. In the config file `configs/SD14_LaWa.yaml`, set `data_dir` for both the train and validation datasets to the dataset's image folder.
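A quick check that every listed image is present under the configured folder can catch a wrong `data_dir` before training starts. The sketch below rests on assumptions that are not confirmed by the repository: that the CSV holds one image path per row in its first column, relative to `data_dir`, and that it may begin with a header row. The helper name `find_missing_images` is hypothetical.

```python
import csv
from pathlib import Path


def find_missing_images(csv_path, data_dir, skip_header=True):
    """List first-column CSV entries that are not files under `data_dir`.

    Assumes (unverified) that the first column of each row is an image path
    relative to `data_dir`. Set `skip_header=False` if there is no header row.
    """
    data_dir = Path(data_dir)
    missing = []
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        if skip_header:
            next(reader, None)  # discard the header row, if any
        for row in reader:
            if row and not (data_dir / row[0]).is_file():
                missing.append(row[0])
    return missing
```

For example, `find_missing_images("data/train_100k.csv", "/path/to/mirflickr")` would return an empty list when every listed image is found; the folder path here is a placeholder.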
### Train
You can train your modified decoder using:
```python
!python train.py --message_len 48 --config configs/SD14_LaWa.yaml --batch_size 8 --max_epochs 40 --learning_rate 0.00006
```
A batch size of 8 fits on a 32 GB GPU.
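To get a rough sense of the training length these settings imply, the number of optimizer steps can be estimated from the batch size and epoch count. This is a back-of-the-envelope sketch: the figure of 100,000 training images is inferred only from the file name `train_100k.csv`, and the calculation assumes a single GPU with no gradient accumulation.

```python
import math


def optimizer_steps(num_images, batch_size, max_epochs):
    """Return (steps per epoch, total steps), assuming one update per batch."""
    steps_per_epoch = math.ceil(num_images / batch_size)
    return steps_per_epoch, steps_per_epoch * max_epochs


# num_images is an assumption inferred from the name `train_100k.csv`.
per_epoch, total = optimizer_steps(num_images=100_000, batch_size=8, max_epochs=40)
print(per_epoch, total)  # 12500 500000
```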