Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bbrt] Add retry/edit/cut buttons #4120

Merged
merged 4 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 94 additions & 27 deletions packages/bbrt/src/components/chat-message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,24 @@
import { SignalWatcher } from "@lit-labs/signals";
import { LitElement, css, html, nothing, svg } from "lit";
import { customElement, property } from "lit/decorators.js";
import { ForkEvent } from "../llm/fork.js";
import { CutEvent, EditEvent, ForkEvent, RetryEvent } from "../llm/events.js";
import type { ReactiveTurnState } from "../state/turn.js";
import { iconButtonStyle } from "../style/icon-button.js";
import { connectedEffect } from "../util/connected-effect.js";
import "./markdown.js";
import "./tool-call.js";

export interface TurnInfo {
turn: ReactiveTurnState;
index: number;
numTurnsTotal: number;
hideIcon: boolean;
}

@customElement("bbrt-chat-message")
export class BBRTChatMessage extends SignalWatcher(LitElement) {
@property({ type: Object })
accessor turn: ReactiveTurnState | undefined = undefined;

@property({ type: Boolean })
accessor hideIcon = false;
@property({ attribute: false })
accessor info: TurnInfo | undefined = undefined;

static override styles = [
iconButtonStyle,
Expand Down Expand Up @@ -99,6 +103,7 @@ export class BBRTChatMessage extends SignalWatcher(LitElement) {
width: min-content;
position: relative;
margin: 4px auto 12px -10px;
display: flex;
}
:host(:hover) #actions,
:host(:focus) #actions,
Expand All @@ -111,9 +116,20 @@ export class BBRTChatMessage extends SignalWatcher(LitElement) {
}
#actions button {
border: none;
--bb-icon: var(--bb-icon-fork-down-right);
--bb-icon-size: 20px;
}
#editButton {
--bb-icon: var(--bb-icon-edit);
}
#retryButton {
--bb-icon: var(--bb-icon-refresh);
}
#forkButton {
--bb-icon: var(--bb-icon-fork-down-right);
}
#cutButton {
--bb-icon: var(--bb-icon-content-cut);
}
#actions button:not(:hover) {
--bb-button-background: transparent;
}
Expand All @@ -123,21 +139,20 @@ export class BBRTChatMessage extends SignalWatcher(LitElement) {
override connectedCallback() {
super.connectedCallback();
connectedEffect(this, () =>
this.setAttribute("status", this.turn?.status ?? "pending")
this.setAttribute("status", this.info?.turn.status ?? "pending")
);
}

override render() {
// return html`<pre>${JSON.stringify(this.turn?.data ?? {}, null, 2)}</pre>`;
if (!this.turn) {
if (!this.info) {
return nothing;
}
return [
this.#roleIcon,
html`
<div part="contents">
<bbrt-markdown
.markdown=${this.turn.partialText}
.markdown=${this.info.turn.partialText}
part="content"
></bbrt-markdown>
${this.#renderFunctionCalls()}
Expand All @@ -148,54 +163,106 @@ export class BBRTChatMessage extends SignalWatcher(LitElement) {
}

#renderFunctionCalls() {
const calls = this.turn?.partialFunctionCalls;
const calls = this.info?.turn.partialFunctionCalls;
if (!calls?.length) {
return nothing;
}
return html`<div id="toolCalls" part="content">
${calls.map((call) =>
call.render
? call.render()
: html`<bbrt-tool-call .toolCall=${call}></bbrt-tool-call>`
${calls.map(
(call) => html`<bbrt-tool-call .toolCall=${call}></bbrt-tool-call>`
)}
</div>`;
}

get #roleIcon() {
if (!this.turn || this.hideIcon) {
if (!this.info || this.info.hideIcon) {
return nothing;
}
const role = this.hideIcon ? undefined : this.turn.role;
const role = this.info.hideIcon ? undefined : this.info.turn.role;
return html`<svg
aria-label="${role}"
role="img"
part="icon icon-${role} icon-${this.turn.status}"
part="icon icon-${role} icon-${this.info.turn.status}"
>
${role ? svg`<use href="/bbrt/images/${role}.svg#icon"></use>` : nothing}
</svg>`;
}

get #actions() {
if (!this.turn) {
if (!this.info) {
return nothing;
}
return html`
<div id="actions">
const buttons = [];
if (this.info.turn.role === "user") {
buttons.push(html`
<button
id="editButton"
class="bb-icon-button"
title="Edit"
@click=${this.#onClickEditButton}
></button>
<button
id="retryButton"
class="bb-icon-button"
title="Retry"
@click=${this.#onClickRetryButton}
></button>
`);
}
if (
this.info.turn.role === "model" &&
this.info.turn.status === "done" &&
// If the turn has function calls, we don't allow splicing, because it is
// expected that there is always another turn to come, so the cut should
// happen there instead.
this.info.turn.partialFunctionCalls.length === 0 &&
// No reason to splice if we're already at the end.
this.info.index < this.info.numTurnsTotal - 1
) {
buttons.push(html`
<button
id="cutButton"
class="bb-icon-button"
title="Cut"
@click=${this.#onClickCutButton}
></button>
<button
id="forkButton"
class="bb-icon-button"
title="Fork"
@click=${this.#onClickForkButton}
></button>
</div>
`;
></button>
</div>`);
}
return html`<div id="actions">${buttons}</div>`;
}

#onClickCutButton() {
if (!this.info) {
return;
}
this.dispatchEvent(new CutEvent(this.info.turn));
}

#onClickForkButton() {
if (!this.turn) {
if (!this.info) {
return;
}
this.dispatchEvent(new ForkEvent(this.info.turn));
}

#onClickRetryButton() {
if (!this.info) {
return;
}
this.dispatchEvent(new RetryEvent(this.info.turn));
}

#onClickEditButton() {
if (!this.info) {
return;
}
this.dispatchEvent(new ForkEvent(this.turn));
this.dispatchEvent(new EditEvent(this.info.turn));
}
}

Expand Down
10 changes: 6 additions & 4 deletions packages/bbrt/src/components/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,19 @@ export class BBRTChat extends SignalWatcher(LitElement) {
return this.conversation.state.turns.map(
(turn, i) =>
html`<bbrt-chat-message
.turn=${turn}
.hideIcon=${
.info=${{
turn,
// Hide the icon if the previous turn role was the same (since
// otherwise we see two of the same icons in a row, which looks
// weird).
// TODO(aomarks) Some kind of visual indication would
// actually be nice, though, because it's ambiguous sometimes if
// e.g. one turn had multiple tool calls, or there was a sequence of
// tool calls.
turn.role === turns[i - 1]?.role
}
hideIcon: turn.role === turns[i - 1]?.role,
index: i,
numTurnsTotal: turns.length,
}}
></bbrt-chat-message>`
);
}
Expand Down
74 changes: 56 additions & 18 deletions packages/bbrt/src/components/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import type { BBRTDriver } from "../drivers/driver-interface.js";
import { GeminiDriver } from "../drivers/gemini.js";
import { OpenAiDriver } from "../drivers/openai.js";
import { Conversation } from "../llm/conversation.js";
import type { ForkEvent } from "../llm/fork.js";
import type { CutEvent, ForkEvent, RetryEvent } from "../llm/events.js";
import { BREADBOARD_ASSISTANT_SYSTEM_INSTRUCTION } from "../llm/system-instruction.js";
import { IndexedDBSettingsSecrets } from "../secrets/indexed-db-secrets.js";
import type { SecretsProvider } from "../secrets/secrets-provider.js";
Expand All @@ -44,6 +44,7 @@ import { LocalStoragePersistence } from "../state/local-storage-persistence.js";
import type { Persistence } from "../state/persistence.js";
import { SessionStore } from "../state/session-store.js";
import type { ReactiveSessionState } from "../state/session.js";
import type { ReactiveTurnState } from "../state/turn.js";
import { ActivateTool } from "../tools/activate-tool.js";
import { AddNode } from "../tools/add-node.js";
import { CreateBoard } from "../tools/create-board.js";
Expand All @@ -57,6 +58,7 @@ import "./artifact-display.js";
import "./chat.js";
import "./driver-selector.js";
import "./prompt.js";
import { BBRTPrompt } from "./prompt.js";
import "./resizer.js";
import "./session-picker.js";
import "./tool-palette.js";
Expand All @@ -83,6 +85,7 @@ export class BBRTMain extends SignalWatcher(LitElement) {

readonly #leftBar = createRef();
readonly #rightBar = createRef();
readonly #prompt = createRef<BBRTPrompt>();

readonly #breadboardKits: Kit[] = [
asRuntimeKit(CoreKit),
Expand Down Expand Up @@ -295,14 +298,20 @@ export class BBRTMain extends SignalWatcher(LitElement) {
<bbrt-driver-selector
.conversation=${this.#conversation}
></bbrt-driver-selector>
<bbrt-prompt .conversation=${this.#conversation}></bbrt-prompt>
<bbrt-prompt
.conversation=${this.#conversation}
${ref(this.#prompt)}
></bbrt-prompt>
</div>

<bbrt-chat
.conversation=${this.#conversation}
.appState=${this.#appState}
.sessionStore=${this.#sessions}
@bbrt-fork=${this.#onFork}
@bbrt-retry=${this.#onRetry}
@bbrt-cut=${this.#onCut}
@bbrt-edit=${this.#onEdit}
></bbrt-chat>

<div id="left-sidebar" ${ref(this.#leftBar)}>
Expand Down Expand Up @@ -384,27 +393,37 @@ export class BBRTMain extends SignalWatcher(LitElement) {
return tools;
}

#onFork(forkEvent: ForkEvent) {
if (!this.#sessionState || !this.#appState) {
#onCut(event: CutEvent) {
if (!this.#sessionState) {
return;
}
const sessionEvents = this.#sessionState.events;
let sessionEventIndex = -1;
// TODO(aomarks) Unnecessary O(n). Messages should instead know their event
// index and include it when they dispatch a ForkEvent.
for (let i = 0; i < sessionEvents.length; i++) {
const sessionEvent = sessionEvents[i]!;
if (
sessionEvent.detail.kind === "turn" &&
sessionEvent.detail.turn === forkEvent.turn
) {
sessionEventIndex = i;
break;
}
this.#sessionState.rollback(this.#findEventIndexForTurn(event.turn) + 1);
}

#onRetry(event: RetryEvent) {
if (!this.#sessionState) {
return;
}
if (sessionEventIndex === -1) {
this.#sessionState.rollback(this.#findEventIndexForTurn(event.turn));
this.#conversation?.send(event.turn.partialText);
}

#onEdit(event: RetryEvent) {
const prompt = this.#prompt.value;
if (!this.#sessionState || !prompt) {
return;
}
this.#sessionState.rollback(this.#findEventIndexForTurn(event.turn));
prompt.value = event.turn.partialText;
prompt.focus();
}

#onFork(event: ForkEvent) {
if (!this.#sessionState || !this.#appState) {
return;
}
const sessionEvents = this.#sessionState.events;
const sessionEventIndex = this.#findEventIndexForTurn(event.turn);
const forkEvents = sessionEvents.slice(0, sessionEventIndex + 1);
const appState = this.#appState;
const forkBrief = appState.createSessionBrief(
Expand All @@ -419,6 +438,25 @@ export class BBRTMain extends SignalWatcher(LitElement) {
}
});
}

#findEventIndexForTurn(turn: ReactiveTurnState) {
if (!this.#sessionState) {
return -1;
}
const sessionEvents = this.#sessionState.events;
// TODO(aomarks) Unnecessary O(n). Messages should instead know their event
// index and include it when they dispatch a ForkEvent.
for (let i = 0; i < sessionEvents.length; i++) {
const sessionEvent = sessionEvents[i]!;
if (
sessionEvent.detail.kind === "turn" &&
sessionEvent.detail.turn === turn
) {
return i;
}
}
return -1;
}
}

declare global {
Expand Down
Loading
Loading