Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| import os | |
| # Uncomment to run on cpu | |
| # os.environ["JAX_PLATFORM_NAME"] = "cpu" | |
| os.environ["WANDB_DISABLED"] = "true" | |
| os.environ['WANDB_SILENT']="true" | |
| import random | |
| import re | |
| import torch | |
| import gradio as gr | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from flax.jax_utils import replicate | |
| from flax.training.common_utils import shard, shard_prng_key | |
| from PIL import Image, ImageDraw, ImageFont | |
| from functools import partial | |
| from transformers import CLIPProcessor, FlaxCLIPModel, AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel | |
| from dalle_mini import DalleBart, DalleBartProcessor | |
| from vqgan_jax.modeling_flax_vqgan import VQModel | |
# --- Checkpoint identifiers -------------------------------------------------
DALLE_REPO = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# --- Text-to-image models ---------------------------------------------------
# Both loaders are called with `_do_init=False` and hand back the model object
# together with its parameters as a separate value.
model, params = DalleBart.from_pretrained(
    DALLE_REPO,
    revision=DALLE_COMMIT_ID,
    dtype=jnp.float16,
    _do_init=False,
)
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO,
    revision=VQGAN_COMMIT_ID,
    _do_init=False,
)

# --- Image-captioning model (PyTorch) ---------------------------------------
# Feature extractor, tokenizer and model all come from the same checkpoint.
encoder_checkpoint = decoder_checkpoint = model_checkpoint = (
    "nlpconnect/vit-gpt2-image-captioning"
)

# Run the captioner on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
viz_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
def captioned_strip(images, caption=None, rows=1):
    """Tile ``images`` into a single RGB image, optionally with a caption band.

    Tiles are laid out column-major: image ``i`` is pasted into column
    ``i // rows`` and row ``i % rows``. Every tile is positioned using the
    size of ``images[0]``.

    Args:
        images: Non-empty sequence of PIL images (an empty sequence raises
            ``IndexError`` when the first image's size is read).
        caption: Optional text drawn in white near the top-left corner. When
            given, a 24-pixel band is reserved above the tiles.
        rows: Number of tile rows in the strip.

    Returns:
        The composed PIL image.
    """
    band_height = 0 if caption is None else 24
    tile_w, tile_h = images[0].size
    strip = Image.new(
        "RGB", (len(images) * tile_w // rows, tile_h * rows + band_height)
    )
    for index, tile in enumerate(images):
        column, row = divmod(index, rows)
        strip.paste(tile, (column * tile_w, band_height + row * tile_h))

    if caption is not None:
        try:
            font = ImageFont.truetype("LiberationMono-Bold.ttf", 7)
        except OSError:
            # The Liberation font is not guaranteed to be installed on the
            # host; draw the caption with Pillow's built-in font rather than
            # failing the whole strip.
            font = ImageFont.load_default()
        ImageDraw.Draw(strip).text((20, 3), caption, (255, 255, 255), font=font)
    return strip
def get_images(indices, params):
    """Decode ``indices`` through the module-level ``vqgan`` using ``params``.

    Thin wrapper around ``vqgan.decode_code``; its result is returned as is.
    """
    decoded = vqgan.decode_code(indices, params=params)
    return decoded
def predict_caption(image, max_length=128, num_beams=4):
    """Generate a one-line caption for a PIL image.

    Args:
        image: PIL image; it is converted to RGB before feature extraction.
        max_length: Maximum sequence length passed to ``generate``.
        num_beams: Beam-search width passed to ``generate``.

    Returns:
        The decoded caption with every ``<|endoftext|>`` marker removed,
        truncated at the first newline.
    """
    rgb_image = image.convert("RGB")
    pixel_values = feature_extractor(
        rgb_image, return_tensors="pt"
    ).pixel_values.to(device)
    # `num_beams` was previously accepted but never forwarded, so changing it
    # had no effect on decoding.
    caption_ids = viz_model.generate(
        pixel_values, max_length=max_length, num_beams=num_beams
    )[0]
    decoded = tokenizer.decode(caption_ids)
    return decoded.replace("<|endoftext|>", "").split("\n")[0]
# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    """Run ``model.generate`` in parallel across devices.

    The function is wrapped in ``jax.pmap`` because its caller passes inputs
    and parameters that carry a leading device axis (built with ``replicate``)
    and a PRNG key split per device (``shard_prng_key``); without the mapping
    those extra axes reach ``model.generate`` directly.

    Args:
        tokenized_prompt: Mapping of tokenized inputs, expanded as keyword
            arguments to ``model.generate``; mapped over the device axis.
        key: Per-device PRNG key; mapped over the device axis.
        params: Model parameters; mapped over the device axis.
        top_k, top_p, temperature, condition_scale: Sampling settings. They
            are declared static (argnums 3-6), so they are broadcast to every
            device instead of being mapped and must be hashable.

    Returns:
        Whatever ``model.generate`` returns, with a leading device axis.
    """
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )
# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """Decode image codes with the module-level ``vqgan`` on every device.

    Wrapped in ``jax.pmap`` because the caller passes per-device ``indices``
    together with VQGAN parameters that were given a leading device axis via
    ``replicate``.

    Args:
        indices: Code indices; mapped over the leading device axis.
        params: VQGAN parameters; mapped over the leading device axis.

    Returns:
        The result of ``vqgan.decode_code`` with a leading device axis.
    """
    return vqgan.decode_code(indices, params=params)
# Device-parallel variant of `get_images`, mapped over the "batch" axis.
p_get_images = jax.pmap(get_images, axis_name="batch")

# Give both parameter sets a leading device axis.
params, vqgan_params = replicate(params), replicate(vqgan_params)

# Prompt processor for the DALL-E checkpoint.
processor = DalleBartProcessor.from_pretrained(
    DALLE_REPO,
    revision=DALLE_COMMIT_ID,
)
print("Initialized DalleBartProcessor")

# CLIP model, loaded at import time.
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
print("Initialized FlaxCLIPModel")
def hallucinate(prompt, num_images=8):
    """Generate images for a text prompt.

    Uses the module-level ``processor``, ``p_generate``, ``p_decode``,
    ``params`` and ``vqgan_params``. A fresh random seed is drawn on every
    call, so results are not reproducible between calls.

    Args:
        prompt: Text prompt.
        num_images: Requested number of images. Generation runs
            ``max(num_images // jax.device_count(), 1)`` rounds, so at least
            one round is always performed.

    Returns:
        List of PIL images; decoded output is clipped to ``[0, 1]``, reshaped
        to ``(-1, 256, 256, 3)`` and converted to ``uint8``.
    """
    # Sampling settings forwarded to `p_generate`.
    gen_top_k = None
    gen_top_p = None
    temperature = None
    cond_scale = 10.0

    print(f"Prompts: {prompt}")

    # Tokenise a single copy of the prompt. `replicate` already gives every
    # device its own copy; duplicating the prompt `device_count()` times first
    # made each round produce `device_count() ** 2` samples, while the round
    # count below assumes `device_count()` per round.
    tokenized_prompt = replicate(processor([prompt]))

    # create a random key
    seed = random.randint(0, 2**32 - 1)
    key = jax.random.PRNGKey(seed)

    num_rounds = max(num_images // jax.device_count(), 1)
    images = []
    for i in range(num_rounds):
        key, subkey = jax.random.split(key)
        encoded_images = p_generate(
            tokenized_prompt,
            shard_prng_key(subkey),
            params,
            gen_top_k,
            gen_top_p,
            temperature,
            cond_scale,
        )
        print(f"Encoded image {i}")
        # remove BOS
        encoded_images = encoded_images.sequences[..., 1:]
        # decode images
        decoded_images = p_decode(encoded_images, vqgan_params)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
        images.extend(
            Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            for decoded_img in decoded_images
        )
        print(f"Finished decoding image {i}")
    return images
def run_inference(prompt, num_roundtrips=3, num_images=1):
    """Alternate between image generation and captioning.

    Each round generates images for the current prompt with ``hallucinate``,
    captions the first one with ``predict_caption`` and uses that caption as
    the prompt for the next round.

    Args:
        prompt: Initial text prompt.
        num_roundtrips: Number of generate/caption rounds (converted with
            ``int``).
        num_images: Forwarded to ``hallucinate``; only the first returned
            image is used.

    Returns:
        Flat list with three entries per round, in order: an HTML title
        snippet, the generated PIL image and an HTML caption snippet.
    """
    # Function-scope import keeps this block self-contained.
    import html

    outputs = []
    for i in range(int(num_roundtrips)):
        images = hallucinate(prompt, num_images=num_images)
        image = images[0]
        print("Generated image")
        caption = predict_caption(image)
        print(f"Predicted caption: {caption}")

        # The prompt is user-supplied and the caption is model-generated; both
        # end up inside HTML, so escape them for display. The raw caption is
        # still what gets fed into the next round.
        shown_prompt = html.escape(prompt)
        shown_caption = html.escape(caption)

        output_title = f"""
<font size="+3">
<b>[Roundtrip {i}]</b><br>
Prompt: {shown_prompt}<br>
馃 :<br></font>"""
        output_caption = f"""
<font size="+3">
馃馃挰 : {shown_caption}<br>
</font>
"""
        outputs.append(output_title)
        outputs.append(image)
        outputs.append(output_caption)
        prompt = caption
    print("Done.")
    return outputs
# --- Gradio UI ---------------------------------------------------------------
inputs = gr.inputs.Textbox(
    label="What prompt do you want to start with?",
    default="cookie monster the horror movie",
)

# num_roundtrips = gr.inputs.Number(default=2, label="How many roundtrips?")
num_roundtrips = 3

# One (title HTML, image, caption HTML) triple per round trip, in the same
# order in which `run_inference` appends its results.
outputs = []
for _ in range(int(num_roundtrips)):
    outputs.extend(
        [
            gr.outputs.HTML(label=""),
            gr.Image(label=""),
            gr.outputs.HTML(label=""),
        ]
    )

description = """
Round trip DALL路E-mini iterates between DALL路E generation and image captioning, inspired by round trip translation! FYI: runtime is forever (~1hr or possibly longer) because the app is running on CPU.
"""
article = "<p style='text-align: center'>Put together by: Najoung Kim | Dall-E Mini code from flax-community/dalle-mini | Caption code from SRDdev/Image-Caption</p>"

# The description above puts the runtime at about an hour. Requests that long
# have to go through Gradio's queue, otherwise they time out, so the queue is
# enabled at launch.
gr.Interface(
    fn=run_inference,
    inputs=[inputs],
    outputs=outputs,
    title="Round Trip DALL路E mini 馃馃攣馃馃挰",
    description=description,
    article=article,
    theme="default",
    css=".output-image, .input-image, .image-preview {height: 256px !important} ",
).launch(enable_queue=True)