Skip to content

Commit

Permalink
Add cosine_distance for sparse vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
mosabua committed Jan 10, 2025
1 parent 3431609 commit 74e4109
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1404,6 +1404,24 @@ public static Double cosineSimilarity(
return dotProduct / (normLeftMap * normRightMap);
}

/**
 * SQL scalar function {@code cosine_distance} for sparse vectors represented as
 * {@code map(varchar, double)}.
 *
 * <p>The result is {@code 1.0 - cosineSimilarity(leftMap, rightMap)}. The similarity helper
 * returns a boxed {@code Double}; when it yields {@code null} this function propagates
 * {@code null} (SQL NULL) instead of failing on auto-unboxing.
 * NOTE(review): presumably the similarity is null when either map contains a null value —
 * confirm against {@code cosineSimilarity}.
 *
 * @param varcharIdentical injected IDENTICAL operator for varchar keys, forwarded to the similarity computation
 * @param varcharHashCode injected HASH_CODE operator for varchar keys, forwarded to the similarity computation
 * @param leftMap first sparse vector
 * @param rightMap second sparse vector
 * @return the cosine distance, or {@code null} when the cosine similarity is {@code null}
 */
@Description("Calculates the cosine distance between the given sparse vectors")
@ScalarFunction
@SqlNullable
@SqlType(StandardTypes.DOUBLE)
public static Double cosineDistance(
        @OperatorDependency(
                operator = IDENTICAL,
                argumentTypes = {"varchar", "varchar"},
                convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionIsIdentical varcharIdentical,
        @OperatorDependency(
                operator = HASH_CODE,
                argumentTypes = "varchar",
                convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockPositionHashCode varcharHashCode,
        @SqlType("map(varchar,double)") SqlMap leftMap,
        @SqlType("map(varchar,double)") SqlMap rightMap)
{
    Double similarity = cosineSimilarity(varcharIdentical, varcharHashCode, leftMap, rightMap);
    if (similarity == null) {
        // Propagate NULL rather than throwing NullPointerException while unboxing.
        return null;
    }
    return 1.0 - similarity;
}

private static double mapDotProduct(BlockPositionIsIdentical varcharIdentical, BlockPositionHashCode varcharHashCode, SqlMap leftMap, SqlMap rightMap)
{
int leftRawOffset = leftMap.getRawOffset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3457,6 +3457,25 @@ public void testCosineSimilarity()
.isNull(DOUBLE);
}

/**
 * Verifies {@code cosine_distance} on sparse vectors (maps from varchar to double):
 * partially overlapping keys, a negative component, fully disjoint keys, and a NULL argument.
 * Expected values are computed as {@code 1 - dotProduct / (norm(left) * norm(right))}.
 */
@Test
public void testCosineDistance()
{
    // Only key 'b' is shared, so the dot product is 2 * 3.
    assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, 2.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
            .isEqualTo(1 - (2 * 3 / (Math.sqrt(5) * Math.sqrt(10))));

    // Keys 'b' and 'c' are shared; the negative 'c' component reduces the dot product.
    assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
            .isEqualTo(1 - ((2 * 3 + -1 * 1) / (Math.sqrt(1 + 4 + 1) * Math.sqrt(1 + 9))));

    // No shared keys: similarity is 0, so the distance is exactly 1.
    assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['d', 'e'], ARRAY[1.0E0, 3.0E0])"))
            .isEqualTo(1.0);

    // A NULL argument yields NULL; assert the result type as well, matching testCosineSimilarity.
    assertThat(assertions.function("cosine_distance", "null", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
            .isNull(DOUBLE);

    // TODO: enable once cosine_distance is confirmed to return NULL (rather than fail) for maps
    // that contain null values. NOTE(review): this case was left disabled in the original change.
    //assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, null])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
    //        .isNull(DOUBLE);
}

@Test
public void testInverseNormalCdf()
{
Expand Down
10 changes: 10 additions & 0 deletions docs/src/main/sphinx/functions/math.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ SELECT cosine_distance(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]);
```
:::

:::{function} cosine_distance(x, y) -> double
:no-index:
Calculates the cosine distance between two sparse vectors, each given as a `map(varchar, double)`:

```sql
SELECT cosine_distance(MAP(ARRAY['a'], ARRAY[1.0]), MAP(ARRAY['a'], ARRAY[2.0]));
-- 0.0
```
:::

:::{function} cosine_similarity(array(double), array(double)) -> double
Calculates the cosine similarity of two dense vectors:

Expand Down

0 comments on commit 74e4109

Please sign in to comment.