feat(docs): add cheatsheet for ML07 #207
base: master
@@ -0,0 +1,72 @@

### Transfer Learning in Machine Learning Cheat Sheet

#### Introduction

Transfer learning (TL) is a machine learning (ML) technique that reuses a previously trained model as the base for a new model on a different task, allowing the new task to be learned from less data. TL is a popular approach in deep learning because it can train deep neural networks with far less data than building a model from scratch.

This cheat sheet provides concise guidance for implementing transfer learning in machine learning models with Keras and TensorFlow.

#### Goals of Transfer Learning

- Speed up training by initializing models with pre-learned parameters.
- Improve performance on new tasks using knowledge from related tasks.
- Enable effective learning with limited labeled data by transferring knowledge from large datasets.

#### Strategies

> **Review comment:** Remove this section and then put a section called "risks of transfer learning". Think risks around Data Leakage and Poisoning with the model. Also talk about …

> **Reply:** I'll research and add around 4-5 risks associated with transfer learning attacks.

1. **Feature Extraction**: Freeze the early layers of a pre-trained model and extract its learned features for the new task.
2. **Fine-tuning**: Unfreeze some layers of the pre-trained model and continue training on new data to adapt it to the specific task (see the sketch after this list).
3. **Domain Adaptation**: Adjust a pre-trained model to a new domain by transferring knowledge while minimizing domain shift.
4. **Multi-task Learning**: Train a model to perform multiple tasks simultaneously, leveraging shared representations for improved performance.
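
In practice, the difference between feature extraction and fine-tuning comes down to which layers stay frozen. A minimal sketch, assuming a Keras `base_model` such as the ResNet50 instance loaded in the implementation below; freezing all but the last 10 layers is an illustrative choice, not a fixed rule:

```python
# Strategy 1 - feature extraction: freeze the entire base model
for layer in base_model.layers:
    layer.trainable = False

# Strategy 2 - fine-tuning: keep early layers frozen,
# unfreeze only the last few for task-specific adaptation
for layer in base_model.layers[:-10]:
    layer.trainable = False
for layer in base_model.layers[-10:]:
    layer.trainable = True
```
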
#### Implementation

> **Review comment:** I would put another section above Implementation that breaks down each mitigation in the ML07_2023-Transfer_Learning_Attack.md doc. So if you look at the Input Validation Cheat Sheet, it starts with Introduction and Goals, and then breaks down the mitigations.

> **Reply:** Thanks for the review @techiemac. What mitigations should I consider here for the transfer learning attack? And how much content should there be?

> **Reply:** And do I enhance the Strategies part or completely replace it with the mitigations?

> **Review comment:** Appreciate the code example! I'll defer to @shsingh, but I think we should include an example of each attack in the cheat sheet. The developer in me really likes that approach, and it makes this more accessible.

> **Reply:** Okay, I will have to go through the repo once and check what's the content in there :)

```python
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
```
This snippet imports TensorFlow and the Keras components used to build a deep learning model based on the ResNet50 architecture.

```python
# Load the pre-trained ResNet50 model
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
```
The pre-trained ResNet50 model is loaded through TensorFlow's Keras API, with weights learned on ImageNet, a large-scale image database used for training convolutional neural networks. Passing `include_top=False` drops the ImageNet classification head so that a custom classifier can be added for the new task.
```python
# Freeze all base-model layers for feature extraction
for layer in base_model.layers:
    layer.trainable = False
```
Freezing layers is the core of transfer learning: knowledge learned on a source task (e.g., ImageNet classification) is carried over to a different target task (e.g., a specific image classification problem). With the layers frozen, the model retains its feature extraction capabilities while custom layers can be added for the target task without disturbing the pre-trained weights.
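
To confirm the freeze took effect before training, you can inspect the model's weight collections; a quick sketch, where the zero count is what this ResNet50 setup should show:

```python
# After freezing, no base-model weights should remain trainable
print(len(base_model.trainable_weights))      # expect 0
print(len(base_model.non_trainable_weights))  # all ResNet50 weights
```
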
```python
# Add a custom classifier on top
x = Flatten()(base_model.output)
x = Dense(256, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)

# Create the new model
model = Model(inputs=base_model.input, outputs=predictions)
```
The new model can now be trained with a suitable optimizer and loss function for the target classification task (`num_classes` is the number of target classes). Since the base layers are frozen, only the custom classification layers are updated, preserving the knowledge learned by the pre-trained ResNet50 model. This approach is commonly used in transfer learning scenarios to adapt pre-trained models to new tasks efficiently.

```python
# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model on new data
model.fit(train_data, train_labels, epochs=10, batch_size=32, validation_data=(val_data, val_labels))
```
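
Once the new head has converged, a common second phase, matching the Fine-tuning strategy above, is to unfreeze the top of the base model and continue training at a much lower learning rate. A minimal sketch; the ten unfrozen layers, the learning rate, and the epoch count are illustrative assumptions:

```python
# Unfreeze the top of the base model for fine-tuning
for layer in base_model.layers[-10:]:
    layer.trainable = True

# Recompile with a low learning rate so pre-trained weights are
# nudged rather than overwritten (recompiling is required for the
# trainable flags to take effect)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Continue training for a few more epochs
model.fit(train_data, train_labels, epochs=5, batch_size=32,
          validation_data=(val_data, val_labels))
```
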
#### Best Practices

> **Review comment:** Remove this section since it talks about transfer learning and not attack mitigations.

> **Reply:** Thanks, got it!

- **Data Augmentation**: Increase data diversity to generalize better, especially with limited labeled data (sketched, together with dropout, after this list).
- **Regularization**: Prevent overfitting during fine-tuning with techniques like dropout or weight decay.
- **Transfer a Subset of Layers**: Avoid overfitting or loss of information by transferring only the relevant layers.
- **Monitor Performance**: Continuously monitor model performance and adjust the training strategy accordingly.
- **Experiment with Architectures**: Explore different architectures and pre-trained models for the best performance.
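
A minimal sketch of the first two practices, reusing the frozen `base_model` and `num_classes` from the implementation above; the augmentation parameters and the 0.5 dropout rate are illustrative assumptions:

```python
from tensorflow.keras import Input, Sequential, layers

# Data augmentation: random transforms applied only during training
data_augmentation = Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

inputs = Input(shape=(224, 224, 3))
x = data_augmentation(inputs)
x = base_model(x, training=False)   # keep frozen batch-norm statistics
x = Flatten()(x)
x = Dense(256, activation='relu')(x)
x = layers.Dropout(0.5)(x)          # regularization against overfitting
outputs = Dense(num_classes, activation='softmax')(x)
model = Model(inputs, outputs)
```
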
#### Conclusion

> **Review comment:** This is not really needed for a cheat sheet. The persona of the cheat sheet is going to download it and then reference relevant sections. It's not really going to be treated as a doc.

> **Reply:** Okay, I'll make the changes.

Transfer learning enables efficient model training and improved performance on new tasks. By following best practices and experimenting with different techniques, practitioners can effectively apply transfer learning to various machine learning tasks.

> **Review comment:** This is really an overview of transfer learning. I would lift the language from the transfer learning attack doc: https://github.com/OWASP/www-project-machine-learning-security-top-10/blob/master/docs/ML07_2023-Transfer_Learning_Attack.md
>
> Keep a bit of the overview, since this is a complex space and it helps the reader better understand. I would also add language around how fine-tuning is a transfer learning technique; I would argue fine-tuning is probably the more commonly used term. Under the lens of LLMs (I know that's not our Top 10), this is becoming one of the more commonly used techniques.

> **Reply:** Alright, got it!