Project author: haofanwang

Project description:
Search photos on Unsplash, with support for joint image+text queries and attention visualization.
Primary language: Jupyter Notebook
Project URL: git://github.com/haofanwang/natural-language-joint-query-search.git


natural-language-joint-query-search

In this project, we support multiple types of query search, including text-to-image, image-to-image, text+text-to-image, and image+text-to-image. To help analyze the retrieved results, we also support visualization of text attention. Attention visualization for images will be supported soon!

Colab Demo

Search photos on Unsplash, with support for joint image+text query search.

Open In Colab

Attention visualization of CLIP.

Open In Colab

Usage

We follow the same environment as the CLIP project:

    $ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
    $ pip install ftfy regex tqdm
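
As a quick sanity check (a minimal sketch, not part of the original setup), you can confirm that the expected PyTorch version is installed and that a CUDA device is visible before loading CLIP:

    import torch

    # Print the installed PyTorch version and whether a CUDA device is visible.
    print(torch.__version__)          # 1.7.1 if installed with the command above
    print(torch.cuda.is_available())  # True only when cudatoolkit and a GPU are set up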

To visualize the attention of CLIP, we slightly modify the code of CLIP as mentioned here, so you don’t have to install CLIP via the official command. An open-source visualization tool is used in our project; you need to clone it into this repo.

    $ git clone https://github.com/shashwattrivedi/Attention_visualizer.git
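
If you are setting everything up from scratch in a notebook, one possible layout (a sketch; the first URL is the project address listed above) is to clone the visualizer inside the project directory so that the import Attention_visualizer.attention_visualizer used below resolves:

    # Clone the project, enter it, then clone the visualizer into it.
    !git clone https://github.com/haofanwang/natural-language-joint-query-search.git
    %cd natural-language-joint-query-search
    !git clone https://github.com/shashwattrivedi/Attention_visualizer.git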

Download the pre-extracted image IDs and features of the Unsplash dataset from Google Drive, or just run the following commands, and put them under the unsplash-dataset directory. Details can be found in the natural-language-image-search project.

    from pathlib import Path

    # Create a folder for the precomputed features
    !mkdir unsplash-dataset

    # Download from Github Releases
    if not Path('unsplash-dataset/photo_ids.csv').exists():
        !wget https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0/photo_ids.csv -O unsplash-dataset/photo_ids.csv

    if not Path('unsplash-dataset/features.npy').exists():
        !wget https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0/features.npy -O unsplash-dataset/features.npy
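
As an optional check (a minimal sketch; the exact feature dimensionality depends on the CLIP model used to precompute them), you can verify that the two downloaded files line up:

    import numpy as np
    import pandas as pd

    photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
    photo_features = np.load("unsplash-dataset/features.npy")

    # Each row of the feature matrix should correspond to one photo id.
    assert photo_features.shape[0] == len(photo_ids)
    print(photo_features.shape)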

Example of joint query search.

    import torch
    import numpy as np
    import pandas as pd
    from PIL import Image
    from CLIP.clip import clip

    # Encode a text query with CLIP and L2-normalize the feature.
    def encode_search_query(search_query):
        with torch.no_grad():
            # The modified CLIP also returns the text attention weights.
            text_encoded, weight = model.encode_text(clip.tokenize(search_query).to(device))
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
        return text_encoded.cpu().numpy()

    # Rank photos by dot-product similarity and return the top results_count ids.
    def find_best_matches(text_features, photo_features, photo_ids, results_count):
        similarities = (photo_features @ text_features.T).squeeze(1)
        best_photo_idx = (-similarities).argsort()
        return [photo_ids[i] for i in best_photo_idx[:results_count]]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
    photo_ids = list(photo_ids['photo_id'])
    photo_features = np.load("unsplash-dataset/features.npy")

    # text to image
    search_query = "Tokyo Tower at night."
    text_features = encode_search_query(search_query)
    best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, 5)
    for photo_id in best_photo_ids:
        print("https://unsplash.com/photos/{}/download".format(photo_id))

    # image to image
    source_image = "images/borna-hrzina-8IPrifbjo-0-unsplash.jpg"
    with torch.no_grad():
        image_feature = model.encode_image(preprocess(Image.open(source_image)).unsqueeze(0).to(device))
        image_feature = (image_feature / image_feature.norm(dim=-1, keepdim=True)).cpu().numpy()
    best_photo_ids = find_best_matches(image_feature, photo_features, photo_ids, 5)
    for photo_id in best_photo_ids:
        print("https://unsplash.com/photos/{}/download".format(photo_id))

    # text+text to image
    search_query = "red flower"
    search_query_extra = "blue sky"
    text_features = encode_search_query(search_query)
    text_features_extra = encode_search_query(search_query_extra)
    mixed_features = text_features + text_features_extra
    best_photo_ids = find_best_matches(mixed_features, photo_features, photo_ids, 5)
    for photo_id in best_photo_ids:
        print("https://unsplash.com/photos/{}/download".format(photo_id))

    # image+text to image
    search_image = "images/borna-hrzina-8IPrifbjo-0-unsplash.jpg"
    search_text = "cars"
    with torch.no_grad():
        image_feature = model.encode_image(preprocess(Image.open(search_image)).unsqueeze(0).to(device))
        image_feature = (image_feature / image_feature.norm(dim=-1, keepdim=True)).cpu().numpy()
    text_feature = encode_search_query(search_text)
    modified_feature = image_feature + text_feature
    best_photo_ids = find_best_matches(modified_feature, photo_features, photo_ids, 5)
    for photo_id in best_photo_ids:
        print("https://unsplash.com/photos/{}/download".format(photo_id))

Example of CLIP attention visualization. It shows which keywords CLIP uses to retrieve the results. For convenience, all punctuation is removed from the query.

    import torch
    import numpy as np
    import pandas as pd
    from PIL import Image
    from CLIP.clip import clip
    from CLIP.clip import model
    from Attention_visualizer.attention_visualizer import *

    # Rank photos by dot-product similarity and return the top results_count ids.
    def find_best_matches(text_features, photo_features, photo_ids, results_count):
        similarities = (photo_features @ text_features.T).squeeze(1)
        best_photo_idx = (-similarities).argsort()
        return [photo_ids[i] for i in best_photo_idx[:results_count]]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

    photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
    photo_ids = list(photo_ids['photo_id'])
    photo_features = np.load("unsplash-dataset/features.npy")

    search_query = "A red flower is under the blue sky and there is a bee on the flower"

    with torch.no_grad():
        text_token = clip.tokenize(search_query).to(device)
        # The modified CLIP returns both the text feature and the attention weights.
        text_encoded, weight = model.encode_text(text_token)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    text_features = text_encoded.cpu().numpy()
    best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, 5)
    for photo_id in best_photo_ids:
        print("https://unsplash.com/photos/{}/download".format(photo_id))

    sentence = search_query.split(" ")
    # Attention paid by the end-of-text token to each word token in the last layer,
    # with the start-of-text and end-of-text positions dropped.
    attention_weights = list(weight[-1][0][1+len(sentence)].cpu().numpy())[:2+len(sentence)][1:][:-1]
    attention_weights = [float(item) for item in attention_weights]
    display_attention(sentence, attention_weights)
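
If you only want a quick look at the weights without the plotting tool, a minimal sketch (reusing the sentence and attention_weights computed above) is to print the words ranked by attention:

    # Rank words by the attention weight they receive from the end-of-text token.
    for word, w in sorted(zip(sentence, attention_weights), key=lambda p: p[1], reverse=True):
        print("{:.4f}  {}".format(w, word))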

You can also run these examples on Colab via joint-query-search and clip-attention.

Example

Text-to-Image

“Tokyo tower at night.”

Search results for "Tokyo tower at night."

“People come and go on the street.”

Search results for "People come and go on the street."

Image-to-Image

A normal street view. (The left side is the source image)

Search results for a street view image

Text+Text-to-Image

“Flower” + “Blue sky”

Search results for "flower" and "blue sky"

“Flower” + “Bee”

Search results for "flower" and "bee"

Image+Text-to-Image

A normal street view + “cars”

Search results for an empty street with query "cars"

Visualization

“A woman holding an umbrella standing next to a man in a rainy day”

Search results for "A woman holding an umbrella standing next to a man in a rainy day"

“umbrella”, “standing” and “rainy” receive most of the attention.

“A red flower is under the blue sky and there is a bee on the flower”

Search results for "A red flower is under the blue sky and there is a bee on the flower"

“flower”, “sky” and “bee” receive most of the attention.

Acknowledgements

Search photos on Unsplash using natural language descriptions. The search is powered by OpenAI’s CLIP model and the Unsplash Dataset. This project is mostly based on natural-language-image-search.

This project was inspired by these projects: