Skip to content

Fix Issue #211 : Improved Embedding Performance by Handling Base64 Encoding #303

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import com.openai.core.immutableEmptyMap
import com.openai.core.toImmutable
import com.openai.errors.OpenAIInvalidDataException
import java.util.Objects
import java.util.Optional

/** Represents an embedding vector returned by embedding endpoint. */
@NoAutoDetect
Expand All @@ -25,7 +26,7 @@ class Embedding
private constructor(
@JsonProperty("embedding")
@ExcludeMissing
private val embedding: JsonField<List<Double>> = JsonMissing.of(),
private val embedding: JsonField<EmbeddingValue> = JsonMissing.of(),
@JsonProperty("index") @ExcludeMissing private val index: JsonField<Long> = JsonMissing.of(),
@JsonProperty("object") @ExcludeMissing private val object_: JsonValue = JsonMissing.of(),
@JsonAnySetter private val additionalProperties: Map<String, JsonValue> = immutableEmptyMap(),
Expand All @@ -35,7 +36,7 @@ private constructor(
* The embedding vector, which is a list of floats. The length of vector depends on the model as
* listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
*/
fun embedding(): List<Double> = embedding.getRequired("embedding")
fun embedding(): EmbeddingValue = embedding.getRequired("embedding")

/** The index of the embedding in the list of embeddings. */
fun index(): Long = index.getRequired("index")
Expand All @@ -47,7 +48,9 @@ private constructor(
* The embedding vector, which is a list of floats. The length of vector depends on the model as
* listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
*/
@JsonProperty("embedding") @ExcludeMissing fun _embedding(): JsonField<List<Double>> = embedding
@JsonProperty("embedding")
@ExcludeMissing
fun _embedding(): JsonField<EmbeddingValue> = embedding

/** The index of the embedding in the list of embeddings. */
@JsonProperty("index") @ExcludeMissing fun _index(): JsonField<Long> = index
Expand Down Expand Up @@ -92,14 +95,21 @@ private constructor(
/** A builder for [Embedding]. */
class Builder internal constructor() {

private var embedding: JsonField<MutableList<Double>>? = null
private var embedding: JsonField<EmbeddingValue>? = null
private var index: JsonField<Long>? = null
private var object_: JsonValue = JsonValue.from("embedding")
private var additionalProperties: MutableMap<String, JsonValue> = mutableMapOf()

@JvmSynthetic
internal fun from(embedding: Embedding) = apply {
this.embedding = embedding.embedding.map { it.toMutableList() }
this.embedding =
embedding.embedding.map {
EmbeddingValue(
floatEmbedding =
Optional.of(it.floatEmbedding.orElse(mutableListOf()).toMutableList()),
base64Embedding = it.base64Embedding,
)
}
index = embedding.index
object_ = embedding.object_
additionalProperties = embedding.additionalProperties.toMutableMap()
Expand All @@ -110,27 +120,32 @@ private constructor(
* model as listed in the
* [embedding guide](https://platform.openai.com/docs/guides/embeddings).
*/
fun embedding(embedding: List<Double>) = embedding(JsonField.of(embedding))
fun embedding(embedding: EmbeddingValue) = embedding(JsonField.of(embedding))

/**
* The embedding vector, which is a list of floats. The length of vector depends on the
* model as listed in the
* The embedding vector, which is a list of floats or Base64. The float length of vector
* depends on the model as listed in the
* [embedding guide](https://platform.openai.com/docs/guides/embeddings).
*/
fun embedding(embedding: JsonField<List<Double>>) = apply {
this.embedding = embedding.map { it.toMutableList() }
fun embedding(embedding: JsonField<EmbeddingValue>) = apply {
this.embedding =
embedding.map {
EmbeddingValue(
floatEmbedding =
Optional.of(it.floatEmbedding.orElse(mutableListOf()).toMutableList()),
base64Embedding = it.base64Embedding,
)
}
}

/**
* The embedding vector, which is a list of floats. The length of vector depends on the
* model as listed in the
* The embedding vector, which is a list of floats or Base64. The float length of vector
* depends on the model as listed in the
* [embedding guide](https://platform.openai.com/docs/guides/embeddings).
*/
fun addEmbedding(embedding: Double) = apply {
fun addEmbedding(embedding: EmbeddingValue) = apply {
this.embedding =
(this.embedding ?: JsonField.of(mutableListOf())).also {
checkKnown("embedding", it).add(embedding)
}
(this.embedding ?: JsonField.of(embedding)).also { checkKnown("embedding", it) }
}

/** The index of the embedding in the list of embeddings. */
Expand Down Expand Up @@ -163,7 +178,13 @@ private constructor(

fun build(): Embedding =
Embedding(
checkRequired("embedding", embedding).map { it.toImmutable() },
checkRequired("embedding", embedding).map {
EmbeddingValue(
floatEmbedding =
Optional.of(it.floatEmbedding.orElse(mutableListOf()).toMutableList()),
base64Embedding = it.base64Embedding,
)
},
checkRequired("index", index),
object_,
additionalProperties.toImmutable(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,9 @@ private constructor(
private var input: JsonField<Input>? = null
private var model: JsonField<EmbeddingModel>? = null
private var dimensions: JsonField<Long> = JsonMissing.of()
private var encodingFormat: JsonField<EncodingFormat> = JsonMissing.of()
// Default EncodingFormat value is set to BASE64 for performance improvements.
private var encodingFormat: JsonField<EncodingFormat> =
JsonField.of(EncodingFormat.BASE64)
private var user: JsonField<String> = JsonMissing.of()
private var additionalProperties: MutableMap<String, JsonValue> = mutableMapOf()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package com.openai.models.embeddings

import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.Base64
import java.util.Optional
import kotlin.collections.MutableList

/** Represents an embedding vector returned by embedding endpoint. */
@JsonDeserialize(using = EmbeddingValueDeserializer::class)
class EmbeddingValue(
var base64Embedding: Optional<String> = Optional.empty(),
floatEmbedding: Optional<MutableList<Double>> = Optional.empty(),
) {

/**
* The embedding vector, which is a list of float32.
* [embedding guide](https://platform.openai.com/docs/guides/embeddings).
*/
var floatEmbedding: Optional<MutableList<Double>> = Optional.empty()
get() {
if (field.isPresent) {
return field
}
if (base64Embedding.isPresent) {
field = convertBase64ToFloat(base64Embedding)
}
return field
}
set(value) {
field = value
}

/**
* Converting Base64 float32 array to Optional<MutableList>
*
* To improve performance, requests are made in Base64 by default. However, not all developers
* need to decode Base64. Therefore, when a request is made in Base64, the system will
* internally convert the Base64 data to MutableList<Double> and make this converted data
* available, allowing developers to obtain both the Base64 data and the MutableList<Double>
* data by default.
*/
private fun convertBase64ToFloat(
base64Embedding: Optional<String>
): Optional<MutableList<Double>> {
// The response of Embedding returns a List<Float>(float32),
// but the Kotlin API handles MutableList<Double>.
// If we directly convert from List<Float> to MutableList<Double>,
// it increases the precision and changing it from float32 to double.
//
// Since JSON is assigned to MutableList<Double> from a String of JSON Value,
// the precision does not increase.
// Therefore, by first converting the Base64-decoded List<Float> to a String,
// and then converting the String to Double,
// we can handle it as MutableList<Double> without increasing the precision.
return base64Embedding.map { base64String ->
val decoded = Base64.getDecoder().decode(base64String)
val byteBuffer = ByteBuffer.wrap(decoded).order(ByteOrder.LITTLE_ENDIAN)

val floatList = mutableListOf<String>()
while (byteBuffer.hasRemaining()) {
floatList.add(byteBuffer.float.toString())
}
floatList.map { it.replace("f", "").toDouble() }.toMutableList()
}
}

/**
* Output the embedding vector as a string. By default, it will be output as both list of floats
* and Base64 string. if user specifies floatEmbedding, it will be output as list of floats
* only.
*/
override fun toString(): String {
return if (base64Embedding.isPresent) {
"base64: $base64Embedding, float: [${floatEmbedding.get().joinToString(", ")}]"
} else {
"float: [${floatEmbedding.get().joinToString(", ")}]"
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.openai.models.embeddings

import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.databind.DeserializationContext
import com.fasterxml.jackson.databind.JsonDeserializer
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.node.ArrayNode
import java.io.IOException
import java.util.Optional

/** JsonDeserializer for EmbeddingValue */
class EmbeddingValueDeserializer : JsonDeserializer<EmbeddingValue>() {
@Throws(IOException::class)

/*
* Deserialize the JSON representation of an EmbeddingValue.
* The JSON can either be an array of floats or a base64 string.
*/
override fun deserialize(jp: JsonParser, ctxt: DeserializationContext): EmbeddingValue {
val node = jp.codec.readTree<JsonNode>(jp)
val embeddingValue = EmbeddingValue()

if (node.isArray) {
val floats = mutableListOf<Double>()
(node as ArrayNode).forEach { item -> floats.add(item.asDouble()) }
embeddingValue.floatEmbedding = Optional.of(floats)
} else if (node.isTextual) {
embeddingValue.base64Embedding = Optional.of(node.asText())
}
return embeddingValue
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

package com.openai.models.embeddings

import java.util.Optional
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test

Expand All @@ -11,15 +12,35 @@ class CreateEmbeddingResponseTest {
fun createCreateEmbeddingResponse() {
val createEmbeddingResponse =
CreateEmbeddingResponse.builder()
.addData(Embedding.builder().addEmbedding(0.0).index(0L).build())
.addData(
Embedding.builder()
.addEmbedding(
EmbeddingValue(
floatEmbedding = Optional.of(mutableListOf(0.0)),
base64Embedding = Optional.empty(),
)
)
.index(0L)
.build()
)
.model("model")
.usage(
CreateEmbeddingResponse.Usage.builder().promptTokens(0L).totalTokens(0L).build()
)
.build()
assertThat(createEmbeddingResponse).isNotNull
assertThat(createEmbeddingResponse.data())
.containsExactly(Embedding.builder().addEmbedding(0.0).index(0L).build())
.containsExactly(
Embedding.builder()
.addEmbedding(
EmbeddingValue(
floatEmbedding = Optional.of(mutableListOf(0.0)),
base64Embedding = Optional.empty(),
)
)
.index(0L)
.build()
)
assertThat(createEmbeddingResponse.model()).isEqualTo("model")
assertThat(createEmbeddingResponse.usage())
.isEqualTo(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
// File generated from our OpenAPI spec by Stainless.

package com.openai.models.embeddings

import java.util.Optional
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test

class EmbeddingTest {

@Test
fun createEmbedding() {
val embedding = Embedding.builder().addEmbedding(0.0).index(0L).build()
val embedding =
Embedding.builder()
.addEmbedding(
EmbeddingValue(
floatEmbedding = Optional.of(mutableListOf(0.0)),
base64Embedding = Optional.empty(),
)
)
.build()
assertThat(embedding).isNotNull
assertThat(embedding.embedding()).containsExactly(0.0)
// assertThat(embedding.embedding()).containsExactly(0.0)
assertThat(embedding.embedding().floatEmbedding).containsSame(mutableListOf(0.0))
assertThat(embedding.index()).isEqualTo(0L)
}
}