In [1]:
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
import torch
import json
from utils.misc import *
import gc
import pandas as pd
import pandas as pd
from tqdm import tqdm
from utils.clip_eval import CLIPEvaluator
from transformers import AutoImageProcessor, AutoModel
from utils.layout_control import *
from transformers import CLIPTextModel
from matplotlib import pyplot as plt
import math


# Concepts treated as backgrounds: get_concept_prompts() composes their
# single-concept prompts as "<object> in front of <concept>" instead of
# "<concept> and <object>".
BACKGROUND = ['barn']
# Object concept names. NOTE(review): not referenced anywhere in the visible
# part of this file -- confirm whether it is still needed.
OBJ = ['cat', 'dog', 'chair', 'table', 'flower', 'wooden_pot']


def init_generative_model(args, device):
    """
    Initialize the generative pipeline selected by ``args.model_name``
    params:
        args: argparse.Namespace, reads ``model_name`` and (for every model
              except 'sd-v1-5') ``spatial_inversion``
        device: str, device the pipeline is created on / moved to
    returns:
        pipe: a ``Layout_Control`` wrapper when spatial inversion is requested
              for a personalized model, otherwise a ``StableDiffusionPipeline``
    raises:
        ValueError: if ``args.model_name`` is not one of the supported models
    """
    # base checkpoint each supported model is built on
    base_models = {
        "sd-v1-5": "runwayml/stable-diffusion-v1-5",
        "textual_inversion": "runwayml/stable-diffusion-v1-5",
        "custom_diffusion": "CompVis/stable-diffusion-v1-4",
        "dreambooth": "CompVis/stable-diffusion-v1-4",
    }
    if args.model_name not in base_models:
        # previously this fell through to `return pipe` with `pipe` unbound
        raise ValueError(f'Unknown model name {args.model_name}')
    model_id = base_models[args.model_name]

    # the plain sd-v1-5 baseline never uses the layout-control wrapper
    if args.model_name != "sd-v1-5" and args.spatial_inversion:
        return Layout_Control(model_id=model_id, device=device)
    return StableDiffusionPipeline.from_pretrained(model_id).to(device)


def check_identifier_token_dir(checkpoint, concept_str):
    """
    Resolve the directory holding the trained weights for ``concept_str``
    params:
        checkpoint: str, a directory path, optionally containing a ``{}`` placeholder
        concept_str: str, concept name(s) used as sub-directory / placeholder value
    returns:
        the first existing candidate among ``<checkpoint>/<concept_str>`` and
        ``checkpoint.format(concept_str)``; ``checkpoint`` itself if neither exists
    """
    nested = os.path.join(checkpoint, concept_str)
    if os.path.exists(nested):
        return nested

    formatted = checkpoint.format(concept_str)
    if os.path.exists(formatted):
        return formatted

    # neither candidate exists: hand back the checkpoint path untouched
    return checkpoint


def load_trained_weights(args, pipe, checkpoint, concepts_str, c_identifier):
    """
    Load the personalization weights for ``concepts_str`` into ``pipe``
    params:
        args: argparse.Namespace, reads ``model_name`` and ``spatial_inversion``
              (``spatial_inversion`` is evaluated by truthiness, as in
              init_generative_model)
        pipe: StableDiffusionPipeline or Layout_Control
        checkpoint: str, where the additional unique identifier weights are saved
        concepts_str: str, comma-separated concepts used to generate images
        c_identifier: dict, concepts and identifiers, e.g. {'cat': '<cute-cat>'}
    returns:
        pipe: the same pipeline object with the weights loaded
              (returned unchanged for 'sd-v1-5')
    raises:
        ValueError: if ``args.model_name`` is not a supported model
    """
    print(f"loading unique identifier weights for {concepts_str} ...")
    model_name = args.model_name
    if model_name == 'sd-v1-5':
        # the baseline has no personalization weights
        return pipe
    if model_name not in ('textual_inversion', 'custom_diffusion', 'dreambooth'):
        raise ValueError(f'Unknown model name {args.model_name}')

    concepts = concepts_str.split(',')

    if model_name == 'textual_inversion':
        # one embedding file per concept, each resolved to its own directory
        for cls in concepts:
            weights_dir = check_identifier_token_dir(checkpoint, cls)
            pipe.load_textual_inversion(os.path.join(weights_dir, 'learned_embeds.bin'))
        return pipe

    # custom_diffusion / dreambooth keep all weights of a concept group together
    weights_dir = check_identifier_token_dir(checkpoint, concepts_str)

    if model_name == 'custom_diffusion':
        if args.spatial_inversion:
            # Layout_Control exposes load_attn_procs on the wrapper itself
            pipe.load_attn_procs(weights_dir, weight_name="pytorch_custom_diffusion_weights.bin")
            for cls in concepts:
                pipe.load_textual_inversion(os.path.join(weights_dir, f"{c_identifier[cls]}.bin"))
        else:
            print(f"loading unique identifier weights from {weights_dir} ...")
            pipe.unet.load_attn_procs(weights_dir, weight_name="pytorch_custom_diffusion_weights.bin")
            for cls in concepts:
                pipe.load_textual_inversion(weights_dir, weight_name=f"{c_identifier[cls]}.bin")
        return pipe

    # dreambooth
    if args.spatial_inversion:
        pipe.load_dreambooth_weights(weights_dir)
        for cls in concepts:
            pipe.load_textual_inversion(os.path.join(weights_dir, f"{c_identifier[cls]}.bin"))
        return pipe

    pipe.unet = UNet2DConditionModel.from_pretrained(os.path.join(weights_dir, 'unet')).to(pipe.device)
    text_encoder_dir = os.path.join(weights_dir, 'text_encoder')
    if os.path.exists(text_encoder_dir):
        # fixed: the original assigned to the misspelled attribute
        # `text_encoer`, so the fine-tuned text encoder was never used
        pipe.text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir).to(pipe.device)
    for cls in concepts:
        # identifier embeddings are optional for dreambooth checkpoints
        embedding_path = os.path.join(weights_dir, f"{c_identifier[cls]}.bin")
        if os.path.exists(embedding_path):
            print(f"loading unique identifier weights from {c_identifier[cls]}.bin ...")
            pipe.load_textual_inversion(embedding_path)
    return pipe


def save_img(im, prompts, save_dir, name):
    """
    Save one image as ``<save_dir>/<prompts>/<name>.jpg``
    params:
        im: numpy array or PIL image
        prompts: interpolated into the path as a sub-directory name
                 (generate_images passes a single prompt string)
        save_dir: str
        name: file stem, interpolated into the path
    raises:
        TypeError: if ``im`` is neither a numpy array nor a PIL image
    """
    target = os.path.join(save_dir, f'{prompts}/{name}.jpg')
    # presumably creates the parent directory of `target` -- helper is defined elsewhere
    check_mk_file_dir(target)
    if not isinstance(im, (np.ndarray, Image.Image)):
        raise TypeError(f'Unknown type {type(im)}')
    picture = Image.fromarray(im) if isinstance(im, np.ndarray) else im
    picture.save(target)

def get_reference_images(concepts_str, src_img_dir):
    """
    Open the reference images of every concept
    params:
        concepts_str: str, comma-separated concept names
        src_img_dir: str, directory with one sub-directory per concept
    returns:
        dict, {concept: [opened image for every entry of <src_img_dir>/<concept>]}
    """
    def _open_all(concept):
        # every directory entry is opened, in os.listdir order
        return [
            Image.open(os.path.join(src_img_dir, concept, file_name))
            for file_name in os.listdir(os.path.join(src_img_dir, concept))
        ]

    return {concept: _open_all(concept) for concept in concepts_str.split(',')}


def edit_original_prompt(prompt, c_identifier, keys, mode='replace'):
    """
    Rewrite a plain prompt so that it mentions the learned identifiers
    params:
        prompt: str
        c_identifier: dict, concepts and identifiers, e.g. {'cat': '<cute-cat>'}
        keys: concept list, e.g. ['cat', 'dog']; underscores are matched as spaces
        mode: str, 'replace' swaps the concept phrase for its identifier,
              'insert' puts the identifier in front of the phrase,
              'none' leaves the prompt untouched (any other value still checks
              that each phrase is present but edits nothing)
    returns:
        prompt: str, stripped of surrounding whitespace
    raises:
        ValueError: if a concept phrase is missing from the prompt (mode != 'none')
    """
    if mode == 'none':
        return prompt.strip()

    for key in keys:
        phrase = key.replace('_', ' ')
        if phrase not in prompt:
            raise ValueError(f'{phrase} not in prompt {prompt}')
        if mode == 'replace':
            prompt = prompt.replace(phrase, c_identifier[key])
        elif mode == 'insert':
            prompt = prompt.replace(phrase, f'{c_identifier[key]} {phrase}')
    return prompt.strip()


def get_input_dict(prompt, edited_prompt, concepts_str, c_identifier):
    """
    Build the input dict for the layout control generation
    params:
        prompt: str, the original prompt
        edited_prompt: str, the prompt with identifiers
        concepts_str: str, comma-separated concepts
        c_identifier: dict, concepts and identifiers, e.g. {'cat': '<cute-cat>'}
    returns:
        input_dict: dict with keys 'prompt', 'edited_prompt', 'phrases',
                    'identifier' and 'bboxes' (also printed)
    """
    concepts = concepts_str.split(',')

    def _phrase_for(concept):
        # literal match keeps the concept name; a match with underscores read
        # as spaces keeps only the last word; otherwise the concept is skipped
        if concept in prompt:
            return concept
        if concept.replace('_', ' ') in prompt:
            return concept.split('_')[-1]
        return None

    matched = [_phrase_for(concept) for concept in concepts]
    phrases = '; '.join(phrase for phrase in matched if phrase is not None)

    layout_info = get_constant_layout(prompt, concepts_str)
    # swap every concept name in the layout's object string for its identifier
    identifier = layout_info['objs']
    for concept in concepts:
        identifier = identifier.replace(concept, c_identifier[concept])

    input_dict = {
        "prompt": prompt,
        "edited_prompt": edited_prompt,
        "phrases": phrases,
        "identifier": identifier,
        "bboxes": layout_info['bboxes'],
    }
    print(input_dict)
    return input_dict


def get_constant_layout(prompt, concepts_str):
    """
    Assign fixed left/right bounding boxes to the objects of a prompt
    params:
        prompt: str, e.g. 'cat and dog' or 'dog in front of barn'
        concepts_str: str, comma-separated concepts
    returns:
        dict, {'objs': '; '-joined object names, 'bboxes': one [[x0, y0, x1, y1]] per name}
        (both are empty when there are more than two concepts)
    """
    def _left_box():
        return [[0.1, 0.2, 0.5, 0.8]]

    def _right_box():
        return [[0.6, 0.2, 0.95, 0.8]]

    if ' and ' in prompt:
        pieces = prompt.split(' and ')
        first, second = pieces[0].strip(), pieces[-1].strip()
    elif 'in front of' in prompt:
        # the part after 'in front of' becomes the first object
        pieces = prompt.split(' in front of ')
        first, second = pieces[-1].strip(), pieces[0].strip()
    else:
        # no conjunction in the prompt: requires exactly two concepts
        first, second = concepts_str.split(',')

    concepts = concepts_str.split(',')
    if len(concepts) == 1:
        # the concept takes the left box; every word of the other object
        # gets its own copy of the right box
        other_words = second.split(' ')
        names = concepts + other_words
        boxes = [_left_box()] + [_right_box() for _ in other_words]
    elif len(concepts) == 2:
        names = [first.replace(' ', '_'), second.replace(' ', '_')]
        boxes = [_left_box(), _right_box()]
    else:
        names, boxes = [], []
    return {'objs': '; '.join(names), 'bboxes': boxes}


def generate_images(args, pipe, c_p, c_identifier):
    """
    Generate and save images for every concept group and prompt
    params:
        args: argparse.Namespace, reads checkpoint, device, edit_mode,
              spatial_inversion, num_per_prompt and img_save_dir
        pipe: StableDiffusionPipeline or Layout_Control
        c_p: dict, concepts and prompts, e.g. {'cat': ['a photo of a cat']}
        c_identifier: dict, concepts and identifiers, e.g. {'cat': '<cute-cat>'}
    """
    for concepts_str, prompts in c_p.items():
        try:
            pipe = load_trained_weights(args, pipe, args.checkpoint, concepts_str, c_identifier)
        except Exception as e:
            # best-effort recovery: drop the pipeline, release cached GPU
            # memory, rebuild it from scratch and retry the load once
            print(e)
            print("Loading unique identifier failed, re-initializing the model...")
            del pipe
            gc.collect()
            torch.cuda.empty_cache()
            pipe = init_generative_model(args, args.device)
            pipe = load_trained_weights(args, pipe, args.checkpoint, concepts_str, c_identifier)

        concept_keys = concepts_str.split(',')
        for prompt in prompts:
            edited_prompt = edit_original_prompt(prompt, c_identifier, concept_keys, mode=args.edit_mode)
            print("Generating images for prompt: ", edited_prompt)
            # re-seeded per prompt; the same generator is shared by all samples of the prompt
            generator = torch.manual_seed(8888)
            use_layout = args.spatial_inversion
            input_dict = get_input_dict(prompt, edited_prompt, concepts_str, c_identifier) if use_layout else None
            for idx in range(args.num_per_prompt):
                if use_layout:
                    im = pipe(input_dict, num_steps=50, guidance_scale=7.5, generator=generator)[0]
                else:
                    im = pipe(edited_prompt, num_inference_steps=50, guidance_scale=7.5, generator=generator).images[0]
                # images are filed under the ORIGINAL prompt, not the edited one
                save_img(im, prompt, os.path.join(args.img_save_dir, concepts_str), idx)
    print("all images saved in ", args.img_save_dir)


def check_dir(img_save_dir, concepts_str, prompt):
    """
    Return ``<img_save_dir>/<concepts_str>/<prompt>`` for the alignment function
    raises:
        ValueError: if that path does not exist
    """
    target = os.path.join(img_save_dir, concepts_str, prompt)
    if not os.path.exists(target):
        raise ValueError(f'no such dir: {target}')
    return target


def eval_alignment(evaluator, c_p, img_save_dir, src_img_dir):
    """
    Evaluate the alignment between the generated images and the source images/texts
    params:
        evaluator: CLIPEvaluator
        c_p: dict, concepts and prompts
        img_save_dir: str, where the generated images are saved
        src_img_dir: str, where the source images are saved
    returns:
        dict, {'img': {concepts_str: {concept: mean image-image similarity}},
               'text': {concepts_str: mean text-image similarity}}
    """
    img_img_sim = {}   # {c_str: {c1: [s1, s2, s3], c2: [s1, s2, s3]}}
    text_img_sim = {}  # {c_str: [s1, s2, s3]}

    for concepts_str, prompts in c_p.items():
        print(f'evaluating {concepts_str}...')
        image_scores = img_img_sim[concepts_str] = {}
        text_scores = text_img_sim[concepts_str] = []

        # features of the reference images, computed once per concept group
        references = get_reference_images(concepts_str, src_img_dir)
        reference_features = {}
        for concept, reference_imgs in references.items():
            reference_features[concept] = [evaluator.get_image_features(ref) for ref in reference_imgs]
            image_scores[concept] = []

        for prompt in prompts:
            text_feature = evaluator.get_text_features(prompt)
            sample_dir = check_dir(img_save_dir, concepts_str, prompt)
            for file_name in os.listdir(sample_dir):
                img = Image.open(os.path.join(sample_dir, file_name))
                # text alignment; the similarity is scaled by a constant 2.5
                # (NOTE(review): presumably the CLIPScore weighting -- confirm)
                text_sim = evaluator.txt_to_img_similarity(img, text_features=text_feature)
                text_scores.append(2.5*text_sim.cpu().numpy())
                # image alignment against every reference image of every concept
                for concept, features in reference_features.items():
                    image_scores[concept].extend(
                        [evaluator.img_to_img_similarity(img, src_img_features=feature).cpu().numpy()
                         for feature in features]
                    )

    img_img_sim_mean = {}
    text_img_sim_mean = {}
    for concepts_str, per_concept in img_img_sim.items():
        img_img_sim_mean[concepts_str] = {concept: np.mean(sims) for concept, sims in per_concept.items()}
        text_img_sim_mean[concepts_str] = np.mean(text_img_sim[concepts_str])
        print(f'{concepts_str}: image alignment -- {img_img_sim_mean[concepts_str]} | text alignment -- {text_img_sim_mean[concepts_str]}')
    return {'img': img_img_sim_mean, 'text': text_img_sim_mean}

def eval_coco_objs(evaluator, c_p, img_save_dir):
    """
    Run the evaluator's object detection on every generated image and average the result
    params:
        evaluator: CLIPEvaluator, must provide detect_objects(img, prompt)
        c_p: dict, concepts and prompts
        img_save_dir: str, where the generated images are saved
    returns:
        dict, {'coco_coi': {concepts_str: mean of the detect_objects values}}
    raises:
        ValueError: if the image directory of a prompt does not exist
    """
    coco_objs = {}
    for concepts_str, prompts in c_p.items():
        print(f'evaluating {concepts_str} coco objects...')
        detections = coco_objs[concepts_str] = []
        for prompt in prompts:
            sample_dir = os.path.join(img_save_dir, concepts_str, prompt)
            if not os.path.exists(sample_dir):
                raise ValueError(f'no such dir: {sample_dir}')
            for file_name in os.listdir(sample_dir):
                img = Image.open(os.path.join(sample_dir, file_name))
                detections.append(evaluator.detect_objects(img, prompt))

    coco_objs_coi = {}
    for concepts_str, detections in coco_objs.items():
        coco_objs_coi[concepts_str] = np.mean(detections)
        print(f'{concepts_str}: coco CoI -- {coco_objs_coi[concepts_str]}')
    return {'coco_coi': coco_objs_coi}


def get_placeholders(concepts_list):
    """
    Map every concept's class prompt to its placeholder token
    params:
        concepts_list: list of dicts, each with 'class_prompt' and 'placeholder'
    returns:
        dict, {class_prompt: placeholder}
    """
    return {entry['class_prompt']: entry['placeholder'] for entry in concepts_list}


def get_concept_prompts(concepts_list, num_inverted_concepts):
    """
    Get the concept prompts
    params:
        concepts_list: list of dicts, each with a 'class_prompt' entry
        num_inverted_concepts: int, the number of inverted concepts (1 or 2)
    returns:
        c_p: dict, {concept_str: [prompts]}
             1 concept:  {'cat': ['cat and <coco object>', ...]}; concepts listed
                         in BACKGROUND use '<coco object> in front of <concept>'
             2 concepts: {'cat,dog': ['cat and dog', 'dog and cat']} for every
                         unordered pair of concepts
    raises:
        ValueError: if num_inverted_concepts is neither 1 nor 2
    """
    if num_inverted_concepts not in (1, 2):
        # previously this reached `return c_p` with `c_p` unbound
        raise ValueError(f'num_inverted_concepts must be 1 or 2, got {num_inverted_concepts}')

    concepts = [concept['class_prompt'] for concept in concepts_list]
    c_p = {}

    if num_inverted_concepts == 1:
        with open("./data/coco.txt", 'r') as f:
            # skip blank lines (e.g. the one after a trailing newline): an empty
            # object name would yield a dangling prompt such as "cat and"
            coco_objs = [line.strip() for line in f if line.strip()]
        for concept in concepts:
            phrase = concept.replace('_', ' ')
            if concept in BACKGROUND:
                c_p[concept] = [f"{obj} in front of {phrase}" for obj in coco_objs]
            else:
                c_p[concept] = [f"{phrase} and {obj}" for obj in coco_objs]
        return c_p

    # two concepts: both orderings for every unordered pair
    for i, first in enumerate(concepts):
        for second in concepts[i + 1:]:
            first_phrase = first.replace('_', ' ')
            second_phrase = second.replace('_', ' ')
            c_p[f'{first},{second}'] = [
                f"{first_phrase} and {second_phrase}",
                f"{second_phrase} and {first_phrase}",
            ]
    return c_p

def generate(args):
    """
    Generate images for every concept prompt
    params:
        args: argparse.Namespace, reads concepts_list_path, num_inverted_concepts
              and device here; passed on to the model and generation helpers
    """
    set_seed(seed=8888)
    # prepare the prompts
    # (context manager: the concepts file is closed as soon as it is parsed)
    with open(args.concepts_list_path, "r") as f:
        concepts_list = json.load(f)
    c_p = get_concept_prompts(concepts_list, args.num_inverted_concepts)   # dict: {concept_str: [prompts]}
    c_identifier = get_placeholders(concepts_list)
    # prepare the generative model
    pipe = init_generative_model(args, args.device)
    # generate images
    generate_images(args, pipe, c_p, c_identifier)


def evaluate(args, eval_clip_score=True, eval_coco_coi=False):
    """
    Evaluate the generated images and save the scores
    params:
        args: argparse.Namespace
        eval_clip_score: bool, whether to evaluate the CLIP score (including image alignment and text-image alignment)
        eval_coco_coi: bool, whether to evaluate the coco CoI score
                       (only takes effect when args.num_inverted_concepts == 1)
    """
    results = {}
    with open(args.concepts_list_path, "r") as f:
        concepts_list = json.load(f)
    c_p = get_concept_prompts(concepts_list, args.num_inverted_concepts)   # dict: {concept_str: [prompts]}

    run_coco_coi = args.num_inverted_concepts == 1 and bool(eval_coco_coi)

    # Both metrics share one evaluator. Build it whenever either metric runs:
    # previously it was only created for the CLIP score, so
    # eval_clip_score=False with eval_coco_coi=True hit an unbound variable.
    evaluator = CLIPEvaluator(args.device) if (eval_clip_score or run_coco_coi) else None

    if eval_clip_score:
        # evaluate the image alignment and text-image alignment
        clip_score = eval_alignment(evaluator, c_p, args.img_save_dir, args.src_img_dir)
        results.update(clip_score)

    if run_coco_coi:
        coco_coi = eval_coco_objs(evaluator, c_p, args.img_save_dir)
        results.update(coco_coi)

    # TODO: KID evaluation (eval_KID) is not implemented
    save_results(args, results)

def run_and_test(args, **kwargs):
    """
    Generate the images and then evaluate them
    params:
        args: passed unchanged to generate() and evaluate()
        **kwargs: forwarded to evaluate() (eval_clip_score, eval_coco_coi)
    """
    generate(args)
    evaluate(args, **kwargs)
    
def save_results(args, scores):
    """
    Append one row (run settings + flattened scores) to the results csv
    params:
        args: provides get_dict() (the run settings) and results_path
        scores: dict, {score_name: {concept_str: score}}; a score may itself be
                a dict {concept: value}
    Columns are named '<concept_str>_<score_name>' or, for nested scores,
    '<concept_str>_<concept>_<score_name>'.
    """
    row = args.get_dict()
    if os.path.exists(args.results_path):
        table = pd.read_csv(args.results_path)
    else:
        # first run: start from an empty table with the settings as columns
        table = pd.DataFrame(columns=list(row.keys()))

    for score_name, per_group in scores.items():
        for concept_str, value in per_group.items():
            if not isinstance(value, dict):
                row[f'{concept_str}_{score_name}'] = value
                continue
            for concept, s in value.items():
                row[f'{concept_str}_{concept}_{score_name}'] = s

    new_row = pd.DataFrame(row, index=[0])
    table = pd.concat([table, new_row], ignore_index=True)
    table.to_csv(args.results_path, index=False)
    print("results saved in ", args.results_path)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Run configuration for one generate-and-evaluate pass.
args = {
    "model_name": "custom_diffusion",  # textual_inversion, dreambooth, custom_diffusion (or sd-v1-5)
    "spatial_inversion": True,   # True or False; True selects the Layout_Control pipeline
    "device": "cuda:0",
    "num_inverted_concepts": 1,   # 1 or 2
    "edit_mode": "insert",  # 'replace', 'insert' or 'none' (see edit_original_prompt)
    "num_per_prompt": 5,  # number of generated images per prompt
    "concepts_list_path": "./data/concepts_list_test.json",
    # "{}" is filled with the concept string by check_identifier_token_dir
    "checkpoint": "./snapshot/compositional_custom_diffusion/{}",   
    "src_img_dir": "./data/reference_images/",
    "results_path": "results.csv",
}
# output directory encodes the model name, the spatial-inversion flag and the number of concepts
args["img_save_dir"] = "./samples/{}_spatial_{}/{}c".format(args['model_name'],args['spatial_inversion'], args['num_inverted_concepts'])
# NOTE(review): Dict2Class is defined elsewhere; presumably it exposes the keys
# as attributes and provides get_dict(), which save_results relies on -- confirm
args = Dict2Class(args)
# the coco CoI evaluation only runs because num_inverted_concepts is 1
run_and_test(args, eval_clip_score=True, eval_coco_coi=True)