Skip to content

Commit

Permalink
Display model sizes
Browse files Browse the repository at this point in the history
Signed-off-by: Fred Bricon <fbricon@gmail.com>
  • Loading branch information
fbricon committed Sep 23, 2024
1 parent 87a8d47 commit 442187e
Show file tree
Hide file tree
Showing 11 changed files with 712 additions and 188 deletions.
67 changes: 38 additions & 29 deletions src/configureAssistant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import * as path from "path";
import * as vscode from "vscode";

export interface AiAssistantConfigurationRequest {
chatModelName: string;
tabModelName: string;
embeddingsModelName: string;
chatModelName: string | null;
tabModelName: string | null;
embeddingsModelName: string | null;
inferenceEndpoint?: string;
provider?: string;
systemMessage?: string;
Expand Down Expand Up @@ -65,37 +65,46 @@ export class AiAssistantConfigurator {
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
const models: Model[] = config.models === undefined ? [] : config.models;
// check if model object is already in the config json
const existing = models.find(
(m) => model.provider === m.provider && model.apiBase === m.apiBase
);
let updateConfig = false;
if (existing) {
if (existing.model !== model.model || existing.title !== model.title) {
existing.model = model.model;
existing.title = model.title;
if (this.request.chatModelName) {
const existing = models.find(
(m) => model.provider === m.provider && model.apiBase === m.apiBase
);
if (existing) {
if (existing.model !== model.model || existing.title !== model.title) {
existing.model = model.model;
existing.title = model.title;
updateConfig = true;
}
} else {
models.push(model);
updateConfig = true;
}
} else {
models.push(model);
updateConfig = true;
config.models = models;
}
config.models = models;
const tabAutocompleteModel: TabAutocompleteModel = {
title: this.request.tabModelName,
model: this.request.tabModelName,
provider: this.request.provider,
};
if (config.tabAutocompleteModel !== tabAutocompleteModel) {
config.tabAutocompleteModel = tabAutocompleteModel;
updateConfig = true;
// Configure tab autocomplete model if it exists
if (this.request.tabModelName) {
const tabAutocompleteModel: TabAutocompleteModel = {
title: this.request.tabModelName,
model: this.request.tabModelName,
provider: this.request.provider,
};
if (config.tabAutocompleteModel !== tabAutocompleteModel) {
config.tabAutocompleteModel = tabAutocompleteModel;
updateConfig = true;
}
}
const embeddingsProvider = {
provider: 'ollama',
model: this.request.embeddingsModelName,
};
if (config.embeddingsProvider !== embeddingsProvider) {
config.embeddingsProvider = embeddingsProvider;
updateConfig = true;

// Configure embeddings model if it exists
if (this.request.embeddingsModelName) {
const embeddingsProvider = {
provider: 'ollama',
model: this.request.embeddingsModelName,
};
if (config.embeddingsProvider !== embeddingsProvider) {
config.embeddingsProvider = embeddingsProvider;
updateConfig = true;
}
}
if (updateConfig) {
await writeConfig(configFile, config);
Expand Down
9 changes: 5 additions & 4 deletions src/modelServer.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import { ProgressData } from "./commons/progressData";

export interface IModelServer {
name: string;
getName(): string;
isServerInstalled(): Promise<boolean>;
startServer(): Promise<boolean>;
installServer(mode: string): Promise<boolean>;
isModelInstalled(modelName: string): Promise<boolean>;
installModel(modelName: string, reportProgress: (progress: ProgressData) => void): Promise<any>;
supportedInstallModes(): Promise<{ id: string; label: string }[]>; //manual, script, homebrew
configureAssistant(
chatModelName: string,
tabModelName: string,
embeddingsModel: string
chatModelName: string | null,
tabModelName: string | null,
embeddingsModel: string | null
): Promise<void>;
listModels(): Promise<string[]>;
}
13 changes: 9 additions & 4 deletions src/ollama/ollamaServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ const PLATFORM = os.platform();
const OLLAMA_URL = "http://localhost:11434";

export class OllamaServer implements IModelServer {
name!: "Ollama";

constructor(private name: string = "Ollama") { }

getName(): string {
return this.name;
}

async supportedInstallModes(): Promise<{ id: string; label: string }[]> {
const modes = [];
Expand Down Expand Up @@ -133,9 +138,9 @@ export class OllamaServer implements IModelServer {
}

async configureAssistant(
chatModelName: string,
tabModelName: string,
embeddingsModelName: string
chatModelName: string | null,
tabModelName: string | null,
embeddingsModelName: string | null
): Promise<void> {
const request = {
chatModelName,
Expand Down
102 changes: 54 additions & 48 deletions src/panels/setupGranitePage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import {
window,
} from "vscode";
import { ProgressData } from "../commons/progressData";
import { IModelServer } from "../modelServer";
import { OllamaServer } from "../ollama/ollamaServer";
import { IModelServer } from '../modelServer';
import { OllamaServer } from '../ollama/ollamaServer';
import { getNonce } from "../utilities/getNonce";
import { getUri } from "../utilities/getUri";

Expand All @@ -26,12 +26,20 @@ import { getUri } from "../utilities/getUri";
* - Setting the HTML (and by proxy CSS/JavaScript) content of the webview panel
* - Setting message listeners so data can be passed between the webview and extension
*/

type GraniteConfiguration = {
tabModelId: string | null;
chatModelId: string | null;
embeddingsModelId: string | null;
};


export class SetupGranitePage {
public static currentPanel: SetupGranitePage | undefined;
private readonly _panel: WebviewPanel;
private _disposables: Disposable[] = [];
private _fileWatcher: fs.FSWatcher | undefined;

private server: IModelServer;
/**
* The HelloWorldPanel class private constructor (called only from the render method).
*
Expand All @@ -40,7 +48,7 @@ export class SetupGranitePage {
*/
private constructor(panel: WebviewPanel, extensionUri: Uri, extensionMode: ExtensionMode) {
this._panel = panel;

this.server = new OllamaServer();
// Set an event listener to listen for when the panel is disposed (i.e. when the user closes
// the panel or when the panel is closed programmatically)
this._panel.onDidDispose(() => this.dispose(), null, this._disposables);
Expand Down Expand Up @@ -205,7 +213,6 @@ export class SetupGranitePage {
private debounceStatus = 0;

private _setWebviewMessageListener(webview: Webview) {
const server = new OllamaServer();

webview.onDidReceiveMessage(
async (message: any) => {
Expand All @@ -217,12 +224,12 @@ export class SetupGranitePage {
webview.postMessage({
command: "init",
data: {
installModes: await server.supportedInstallModes(),
installModes: await this.server.supportedInstallModes(),
},
});
break;
case "installOllama":
await server.installServer(data.mode);
await this.server.installServer(data.mode);
break;
case "fetchStatus":
const now = new Date().getTime();
Expand All @@ -241,18 +248,18 @@ export class SetupGranitePage {
// console.log("Received fetchStatus msg " + debounceStatus);
let models: string[];
try {
models = await server.listModels();
models = await this.server.listModels();
ollamaInstalled = true;
} catch (e) {
//TODO check error response code instead?
models = [];
if (!ollamaInstalled) {
//fall back to checking CLI
ollamaInstalled = await server.isServerInstalled();
ollamaInstalled = await this.server.isServerInstalled();
if (ollamaInstalled) {
try {
await server.startServer();
models = await server.listModels();
await this.server.startServer();
models = await this.server.listModels();
} catch (e) {}
}
}
Expand Down Expand Up @@ -283,7 +290,7 @@ export class SetupGranitePage {
},
});
try {
await setupGranite(data as GraniteConfiguration, reportProgress);
await this.setupGranite(data as GraniteConfiguration, reportProgress);
} finally {
webview.postMessage({
command: "page-update",
Expand All @@ -300,49 +307,48 @@ export class SetupGranitePage {
);
}

}

type GraniteConfiguration = {
tabModelId: string;
chatModelId: string;
embeddingsModelId: string;
};




async function setupGranite(
graniteConfiguration: GraniteConfiguration, reportProgress: (progress: ProgressData) => void): Promise<void> {
async setupGranite(
graniteConfiguration: GraniteConfiguration, reportProgress: (progress: ProgressData) => void): Promise<void> {
//TODO handle continue (conflicting) onboarding page

console.log("Starting Granite Code AI-Assistant...");
const modelServer: IModelServer = new OllamaServer();
console.log("Starting Granite Code AI-Assistant...");

//collect all unique models to install, from graniteConfiguration
const modelsToInstall = new Set([graniteConfiguration.chatModelId, graniteConfiguration.tabModelId, graniteConfiguration.embeddingsModelId]);
//collect all unique models to install, from graniteConfiguration
const modelsToInstall = new Set<string>();
if (graniteConfiguration.chatModelId !== null) {
modelsToInstall.add(graniteConfiguration.chatModelId);
}
if (graniteConfiguration.tabModelId !== null) {
modelsToInstall.add(graniteConfiguration.tabModelId);
}
if (graniteConfiguration.embeddingsModelId !== null) {
modelsToInstall.add(graniteConfiguration.embeddingsModelId);
}

try {
for (const model of modelsToInstall) {
if (await modelServer.isModelInstalled(model)) {
console.log(`${model} is already installed`);
} else {
await modelServer.installModel(model, reportProgress);
try {
for (const model of modelsToInstall) {
if (await this.server.isModelInstalled(model)) {
console.log(`${model} is already installed`);
} else {
await this.server.installModel(model, reportProgress);
}
}
}

modelServer.configureAssistant(
graniteConfiguration.chatModelId,
graniteConfiguration.tabModelId,
graniteConfiguration.embeddingsModelId
);
} catch (error) {
//if error is CancellationError, then we can ignore it
if (error instanceof CancellationError) {
return;
this.server.configureAssistant(
graniteConfiguration.chatModelId,
graniteConfiguration.tabModelId,
graniteConfiguration.embeddingsModelId
);
} catch (error) {
//if error is CancellationError, then we can ignore it
if (error instanceof CancellationError) {
return;
}
throw error;
}
throw error;
}

commands.executeCommand("continue.continueGUIView.focus");
commands.executeCommand("continue.continueGUIView.focus");
}

}

2 changes: 1 addition & 1 deletion webviews/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Hello World</title>
<title>Granite Code Wizard</title>
</head>
<body>
<div id="root"></div>
Expand Down
Loading

0 comments on commit 442187e

Please sign in to comment.