Generating image embeddings on a GPU with LLaVA and llama-cpp-python

Introduction

Everyone nowadays (well, everyone who's experimented with LLMs) knows about text embeddings, which are, after tokenization, the second stage of an LLM processing some text. The cool part about embeddings is that the output of an embedding layer is a vector that's semantically meaningful. So, if I have two embedding vectors that are close to each other, the texts that produced them are also similar to each other.

In Python, with the llama-cpp-python library that uses the llama.cpp library, it's simple enough to generate a text embedding:

from llama_cpp import Llama
import numpy as np

def get_text_embedding(llm: Llama, text: str) -> np.array:
    embed = np.array(llm.embed(text))
    llm.reset()
    return embed

llm = Llama(
  model_path="../llava/ggml-model-q5_k.gguf",
  n_ctx=2048,
  n_batch=1024,
  logits_all=True,
  offload_kqv=True,
  n_gpu_layers=64,
  embedding=True,
  verbose=True
)

embedding = get_text_embedding(llm, "what is going on")
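
To sanity-check the "close vectors mean similar texts" idea, we can compare a few queries by cosine similarity. This is just a quick sketch (the example strings are mine, and the exact numbers will depend on the model):

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity: close to 1.0 means the vectors point in roughly
    # the same direction, close to 0.0 means they're unrelated.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

similar = cosine_similarity(
    get_text_embedding(llm, "a cat napping on the couch"),
    get_text_embedding(llm, "a kitten asleep on a sofa"),
)
different = cosine_similarity(
    get_text_embedding(llm, "a cat napping on the couch"),
    get_text_embedding(llm, "quarterly financial report"),
)
print(similar, different)  # the first one should, hopefully, come out higher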

Well, the not-so-simple part was getting llama.cpp to compile on my Windows machine with CUDA support so that this can run on my GPU instead of the CPU. That's a whole blog post in and of itself and maybe I'll write it someday.

The model (ggml-model-q5_k.gguf) is from https://huggingface.co/mys/ggml_llava-v1.5-7b because this is really a post about LLaVA.

LLaVA

Now, you can do something similar with images. The CLIP vision encoder can take an image and turn it into an embedding (a vector) that represents its semantic meaning. Based on this, some researchers built LLaVA by essentially wiring the vision encoder to a large language model (Vicuna).

There are two moving parts here: the projection matrix and the language model itself.

What is the projection matrix? Well, since the encoded vector is expressed in terms of CLIP's "semantic space", we need to "translate" it to the LLM's space. We do that by multiplying that vector by a projection matrix (this is also called "feature alignment"). The first stage of the training is about finding a projection matrix that does this best. They do it using a subset of the CC3M dataset of images and captions. My understanding is, it works like this:

  • Encode all images (you now have a set of image embedding vectors)
  • Embed all image descriptions using Vicuna (you now have a set of text embedding vectors)
  • Optimize the matrix W such that using it to transform the CLIP-encoded image vectors brings them as close to the embedded descriptions as possible (a toy sketch of this idea follows the list).
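
To make that last step concrete, here's a toy numpy sketch of feature alignment. This is not how LLaVA is actually trained (the real thing optimizes the projection with gradient descent on actual encoder outputs, and the dimensions here are placeholders); it just illustrates "find a W that maps one space onto the other":

import numpy as np

# Stand-ins for the real data: 1000 image/caption pairs, with made-up
# dimensions for the CLIP vectors and the LLM embedding vectors.
rng = np.random.default_rng(0)
clip_vectors = rng.normal(size=(1000, 768))   # encoded images
text_vectors = rng.normal(size=(1000, 4096))  # Vicuna-embedded captions

# Find W minimizing ||clip_vectors @ W - text_vectors||^2 (least squares).
W, *_ = np.linalg.lstsq(clip_vectors, text_vectors, rcond=None)

# clip_vectors @ W now lives in the (toy) LLM embedding space.
projected = clip_vectors @ W
print(projected.shape)  # (1000, 4096)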

The second stage of the training is about fine-tuning both the matrix and the language model. The training data for this phase is 158k sets of instructions, images and responses (for example, "What potential risks could this scenario indicate?" - "The scenario in the image shows one airplane flying over the airport runway with its landing gear down while another plane is parked on the runway. This situation might indicate potential risks related to air traffic control and airport runway management. [...]").

We add the encoded and projected image and the instruction to the model's context and train it to produce the expected response.

Getting the image embedding with llama-cpp-python

So, it looks like going through this process makes the basic CLIP embeddings "smarter": we not only train a matrix to project the original vector into the LLM's space, but also fine-tune it in the second stage together with the rest of the LLM to give the expected language responses.

It would be useful to get our hands on that embedding. Besides being able to do semantic image search (by taking an embedding of the query and finding the image vectors that are closest to it based on cosine similarity), it would also enable things like discovering image clusters or automatic tagging.
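
As a minimal sketch of what that search could look like (assuming we keep precomputed image embeddings in a dict; the function and variable names here are made up), ranking images against a query embedding is just a cosine-similarity sort:

def search_images(
    query_embedding: np.ndarray,
    image_embeddings: dict[str, np.ndarray],
    top_k: int = 5,
) -> list[tuple[str, float]]:
    # Rank images by cosine similarity to the query embedding.
    scores = []
    for name, vec in image_embeddings.items():
        sim = np.dot(query_embedding, vec) / (
            np.linalg.norm(query_embedding) * np.linalg.norm(vec)
        )
        scores.append((name, float(sim)))
    return sorted(scores, key=lambda s: s[1], reverse=True)[:top_k]

# e.g. search_images(get_text_embedding(llm, "cat meme"), image_embeddings)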

There's a recent blog post doing something similar, but in a different way: it uses LLaVA with a text prompt to generate a description of the image and then uses an LLM to generate an embedding of that description. I felt like that would discard some of the information in the original image embedding and would take more time (instead of just processing the image into an embedding, we'd do that, process the instruction, run the LLM to get the response, and then process the response into an embedding again).

First, let's get the model. I downloaded it from a link on the original llama.cpp pull request adding LLaVA support.

There are two files:

  • ggml-model-q5_k.gguf: the actual LLM
  • mmproj-model-f16.gguf: the CLIP encoder and the projection matrix

After some digging around the llama.cpp and llama-cpp-python codebase, I managed to come up with this code snippet:

from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler
import llama_cpp
import ctypes
import array
import time
import numpy as np

def get_image_embedding(llm: Llama, llava: Llava15ChatHandler, path: str) -> np.array:
    # Important, otherwise embeddings seem to leak across different
    # invocations of get_image_embedding
    llm.reset()
    llm._ctx.kv_cache_clear()

    image_bytes = llava.load_image(path)

    data_array = array.array("B", image_bytes)
    c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array)

    t_start = time.time_ns()

    embed = llava._llava_cpp.llava_image_embed_make_with_bytes(
        ctx_clip=llava.clip_ctx,
        image_bytes=c_ubyte_ptr,
        n_threads=6,
        image_bytes_length=len(image_bytes),
    )

    t_embed_llava = time.time_ns()

    n_past = ctypes.c_int(llm.n_tokens)
    n_past_p = ctypes.pointer(n_past)

    # Write the image represented by embed into the llama context
    llava._llava_cpp.llava_eval_image_embed(
        ctx_llama=llm.ctx,
        embed=embed,
        n_batch=llm.n_batch,
        n_past=n_past_p,
    )

    t_eval = time.time_ns()

    print(n_past.value)

    assert llm.n_ctx() >= n_past.value
    llm.n_tokens = n_past.value
    llava._llava_cpp.llava_image_embed_free(embed)

    # Get the embedding out of the LLM
    embedding = np.array(
        llama_cpp.llama_get_embeddings(llm._ctx.ctx)[
            : llama_cpp.llama_n_embd(llm._model.model)
        ]
    )

    t_embed_llm = time.time_ns()

    print(
        f"Total: {float(t_embed_llm - t_start) / 1e6:.2f}ms, LLaVA embed: {float(t_embed_llava - t_start) / 1e6:.2f}ms, LLaMA eval: {float(t_eval - t_embed_llava) / 1e6:.2f}ms, LLaMA embed: {float(t_embed_llm - t_eval) / 1e6:.2f}ms"
    )

    llm.reset()
    return embedding

chat_handler = Llava15ChatHandler(clip_model_path="../llava/mmproj-model-f16.gguf", verbose=True)
llm = Llama(
  model_path="../llava/ggml-model-q5_k.gguf",
  chat_handler=chat_handler,
  n_ctx=2048,
  n_batch=1024,
  logits_all=True,
  n_threads=6,
  offload_kqv=True,
  n_gpu_layers=64,
  embedding=True,
  verbose=True
)

picture_embed = get_image_embedding(llm, chat_handler, "file:///....jpg")

Experiment 1: image-to-image similarity

Let's take it for a spin and get some fresh memes from imgur.com:

  • cat
  • football (not that fresh but good)
  • firetruck
  • microsoft

Compute the embeddings and get their cosine similarity matrix:

urls = {
    "cat": "https://i.imgur.com/Iden4uN.png",
    "football": "https://i.imgur.com/Uz8gkcJ.jpeg",
    "firetruck": "https://i.imgur.com/st51tjm.png",
    "microsoft": "https://i.imgur.com/hPYEf5a.jpeg",
}

def get_similarity_matrix(left: np.array, right: np.array):
    return np.dot(left, right.T) / (
        np.linalg.norm(left, axis=1)[:, np.newaxis]
        * np.linalg.norm(right, axis=1)[np.newaxis, :]
    )

embeddings = {
    name: get_image_embedding(llm, chat_handler, url) for name, url in urls.items()
}
image_embedding_matrix = np.array([embeddings[n] for n in sorted(urls.keys())])
image_sim = get_similarity_matrix(image_embedding_matrix, image_embedding_matrix)

(this takes about 300ms/image on my GPU)

In this matrix, the value at (row, col) denotes how similar image row is to image col. Let's visualize it:

from matplotlib import pyplot as plt

def plot_similarity_matrix(
    sim: np.array, x_labels: list[str], y_labels: list[str], **kwargs
):
    fig = plt.figure(**kwargs)
    ax = plt.gca()
    im = ax.imshow(sim, cmap="Wistia")

    ax.set_xticks(np.arange(len(x_labels)), labels=x_labels)
    ax.set_yticks(np.arange(len(y_labels)), labels=y_labels)
    for i in range(len(x_labels)):
        for j in range(len(y_labels)):
            ax.text(i, j, f"{sim[j, i]:.2f}", ha="center", va="center", color="k")
    fig.tight_layout()

plot_similarity_matrix(
    image_sim, list(sorted(urls.keys())), list(sorted(urls.keys())), figsize=(5, 3)
)

Interestingly enough, it thinks that the Microsoft meme is very similar (0.89 cosine similarity) to the cat meme. Perhaps it's because they use the same font and a similar layout?

Experiment 2: query-to-image similarity

Now let's come up with some search queries that either describe what's happening in the image, mention the text in the image, or try to describe the image in a roundabout way:

queries = [
    "funniest meme",
    "picture of men walking",
    "4chan greentext",
    "are you fucking sorry",
    "apple competitors",
    "cartoon animal lying down",
    "woman comforting man",
    "microsoft",
]

text_embeddings = {q: get_text_embedding(llm, q) for q in queries}
text_embedding_matrix = np.array([text_embeddings[q] for q in queries])
text_sim = get_similarity_matrix(image_embedding_matrix, text_embedding_matrix)
plot_similarity_matrix(text_sim.T, list(sorted(urls.keys())), queries, figsize=(6, 6))

This is kind of fun.

  • "Are you fucking sorry" is close to the actual 4chan greentext with that text verbatim as well as the firetruck meme with the woman comforting a man (I guess he is sorry?).
  • The firetruck meme is the funniest one, according to the LLM.
  • The actual 4chan greentext isn't identified as a 4chan greentext.
  • The Microsoft meme doesn't seem to get hit much. It doesn't match "Microsoft" (which is a string that appears in the image) or "picture of men walking". In fact, it matches "4chan greentext" the most, which is bizarre.

Conclusion

So maybe it doesn't grasp the deep hidden meaning of memes, but it's still pretty cool for a model that can run on a consumer GPU and process several images a second. There are a few other things I want to experiment with. For example, using grammar-based sampling with a JSON Schema to make the LLM describe an image but give a very specific, computer-parseable response. Imagine wiring this to a camera and getting an alert when someone is at the door, or computing how many cats it sees on the street per day, all without sending data to someone else's server.
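
For reference, here's roughly what that could look like using llama-cpp-python's JSON-schema-constrained chat completion (an untested sketch: the prompt, schema and image URL are made up for illustration, the response_format plumbing may differ between library versions, and you'd probably want a Llama instance that isn't in embedding mode for generation):

# Ask LLaVA to describe an image as JSON matching a schema.
response = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://i.imgur.com/Iden4uN.png"}},
                {"type": "text", "text": "Describe what you see as JSON."},
            ],
        }
    ],
    response_format={
        "type": "json_object",
        "schema": {
            "type": "object",
            "properties": {
                "person_at_door": {"type": "boolean"},
                "cat_count": {"type": "integer"},
            },
            "required": ["person_at_door", "cat_count"],
        },
    },
)
print(response["choices"][0]["message"]["content"])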

Yeah, it's been four years since my last blog post. Since then, I ran and sold a startup and didn't do much else. It's nice to have a life again.